Add support for WebSocket protocol negotiation. (#83)

This commit is contained in:
Sunli 2021-08-03 14:43:37 +08:00 committed by GitHub
parent 9fbababc3a
commit ba74787532
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 4 deletions

View file

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
- Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
## Breaking changes

View file

@ -73,6 +73,7 @@ use http_body::Full;
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
use std::pin::Pin;
use std::sync::Arc;
use std::{
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
task::Poll,
@ -96,6 +97,7 @@ where
WebSocketUpgrade {
callback,
config: WebSocketConfig::default(),
protocols: Vec::new().into(),
_request_body: PhantomData,
}
}
@ -169,6 +171,7 @@ impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T1
pub struct WebSocketUpgrade<F, B, T> {
callback: F,
config: WebSocketConfig,
protocols: Arc<[Cow<'static, str>]>,
_request_body: PhantomData<fn() -> (B, T)>,
}
@ -180,6 +183,7 @@ where
Self {
callback: self.callback.clone(),
config: self.config,
protocols: self.protocols.clone(),
_request_body: PhantomData,
}
}
@ -212,6 +216,47 @@ impl<F, B, T> WebSocketUpgrade<F, B, T> {
self.config.max_frame_size = Some(max);
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
///
/// # Examples
///
/// ```no_run
/// # use axum::prelude::*;
/// # use axum::ws::{ws, WebSocket};
/// # use std::net::SocketAddr;
/// #
/// # async fn handle_socket(socket: WebSocket) {
/// # todo!()
/// # }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// let app = route("/ws", ws(handle_socket).protocols(["graphql-ws", "graphql-transport-ws"]));
/// #
/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
/// # hyper::Server::bind(&addr)
/// # .serve(app.into_make_service())
/// # .await
/// # .unwrap();
/// # }
///
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
self.protocols = protocols
.into_iter()
.map(Into::into)
.collect::<Vec<_>>()
.into();
self
}
}
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
@ -230,6 +275,7 @@ where
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let this = self.clone();
let protocols = self.protocols.clone();
ResponseFuture(Box::pin(async move {
if req.method() != http::Method::GET {
@ -257,6 +303,31 @@ where
);
}
// check requested protocols
let protocol =
req.headers()
.get(&header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Protocol` header is invalid",
);
}
}
None => None,
};
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
@ -292,7 +363,7 @@ where
callback.call(socket, input).await;
});
let res = Response::builder()
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
@ -302,9 +373,11 @@ where
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()))
.body(box_body(Full::new(Bytes::new())))
.unwrap();
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
if let Some(protocol) = protocol {
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
Ok(res)
}))