mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Add support for WebSocket protocol negotiation. (#83)
This commit is contained in:
parent
9fbababc3a
commit
ba74787532
2 changed files with 78 additions and 4 deletions
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
# Unreleased
|
||||
|
||||
- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
|
||||
- Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
@ -73,6 +73,7 @@ use http_body::Full;
|
|||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
|
||||
task::Poll,
|
||||
|
@ -96,6 +97,7 @@ where
|
|||
WebSocketUpgrade {
|
||||
callback,
|
||||
config: WebSocketConfig::default(),
|
||||
protocols: Vec::new().into(),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -169,6 +171,7 @@ impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T1
|
|||
pub struct WebSocketUpgrade<F, B, T> {
|
||||
callback: F,
|
||||
config: WebSocketConfig,
|
||||
protocols: Arc<[Cow<'static, str>]>,
|
||||
_request_body: PhantomData<fn() -> (B, T)>,
|
||||
}
|
||||
|
||||
|
@ -180,6 +183,7 @@ where
|
|||
Self {
|
||||
callback: self.callback.clone(),
|
||||
config: self.config,
|
||||
protocols: self.protocols.clone(),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -212,6 +216,47 @@ impl<F, B, T> WebSocketUpgrade<F, B, T> {
|
|||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the known protocols.
|
||||
///
|
||||
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
|
||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use axum::prelude::*;
|
||||
/// # use axum::ws::{ws, WebSocket};
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #
|
||||
/// # async fn handle_socket(socket: WebSocket) {
|
||||
/// # todo!()
|
||||
/// # }
|
||||
/// #
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let app = route("/ws", ws(handle_socket).protocols(["graphql-ws", "graphql-transport-ws"]));
|
||||
/// #
|
||||
/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
/// # hyper::Server::bind(&addr)
|
||||
/// # .serve(app.into_make_service())
|
||||
/// # .await
|
||||
/// # .unwrap();
|
||||
/// # }
|
||||
///
|
||||
/// ```
|
||||
pub fn protocols<I>(mut self, protocols: I) -> Self
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
self.protocols = protocols
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
|
||||
|
@ -230,6 +275,7 @@ where
|
|||
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
let this = self.clone();
|
||||
let protocols = self.protocols.clone();
|
||||
|
||||
ResponseFuture(Box::pin(async move {
|
||||
if req.method() != http::Method::GET {
|
||||
|
@ -257,6 +303,31 @@ where
|
|||
);
|
||||
}
|
||||
|
||||
// check requested protocols
|
||||
let protocol =
|
||||
req.headers()
|
||||
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Protocol` header is invalid",
|
||||
);
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
|
@ -292,7 +363,7 @@ where
|
|||
callback.call(socket, input).await;
|
||||
});
|
||||
|
||||
let res = Response::builder()
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
|
@ -302,9 +373,11 @@ where
|
|||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()))
|
||||
.body(box_body(Full::new(Bytes::new())))
|
||||
.unwrap();
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
|
||||
|
||||
Ok(res)
|
||||
}))
|
||||
|
|
Loading…
Add table
Reference in a new issue