diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 0f371e79..a880e7c1 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -65,11 +65,7 @@ use self::rejection::*; use super::{rejection::*, FromRequest, RequestParts}; -use crate::{ - body, - response::{IntoResponse, Response}, - Error, -}; +use crate::{body, response::Response, Error}; use async_trait::async_trait; use bytes::Bytes; use futures_util::{ @@ -106,7 +102,8 @@ use tokio_tungstenite::{ #[derive(Debug)] pub struct WebSocketUpgrade { config: WebSocketConfig, - protocols: Option]>>, + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + protocol: Option, sec_websocket_key: HeaderValue, on_upgrade: OnUpgrade, sec_websocket_protocol: Option, @@ -137,6 +134,10 @@ impl WebSocketUpgrade { /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and /// return the protocol name. /// + /// The protocols should be listed in decreasing order of preference: if the client offers + /// multiple protocols that the server could support, the server will pick the first one in + /// this list. + /// /// # Examples /// /// ``` @@ -164,13 +165,27 @@ impl WebSocketUpgrade { I: IntoIterator, I::Item: Into>, { - self.protocols = Some( - protocols + if let Some(req_protocols) = self + .sec_websocket_protocol + .as_ref() + .and_then(|p| p.to_str().ok()) + { + self.protocol = protocols .into_iter() + // FIXME: This will often allocate a new `String` and so is less efficient than it + // could be. But that can't be fixed without breaking changes to the public API. .map(Into::into) - .collect::>() - .into(), - ); + .find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol.trim() == protocol) + }) + .map(|protocol| match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), + Cow::Borrowed(s) => HeaderValue::from_static(s), + }); + } + self } @@ -180,15 +195,40 @@ impl WebSocketUpgrade { /// When using `WebSocketUpgrade`, the response produced by this method /// should be returned from the handler. See the [module docs](self) for an /// example. - pub fn on_upgrade(self, callback: F) -> impl IntoResponse + pub fn on_upgrade(self, callback: F) -> Response where F: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future + Send + 'static, { - WebSocketUpgradeResponse { - extractor: self, - callback, + let on_upgrade = self.on_upgrade; + let config = self.config; + + tokio::spawn(async move { + let upgraded = on_upgrade.await.expect("connection upgrade failed"); + let socket = + WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) + .await; + let socket = WebSocket { inner: socket }; + callback(socket).await; + }); + + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header( + header::CONNECTION, + HeaderValue::from_str("upgrade").unwrap(), + ) + .header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap()) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(self.sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); } + + builder.body(body::boxed(body::Empty::new())).unwrap() } } @@ -240,7 +280,7 @@ where Ok(Self { config: Default::default(), - protocols: None, + protocol: None, sec_websocket_key, on_upgrade, sec_websocket_protocol, @@ -286,79 +326,6 @@ fn header_contains( } } -struct WebSocketUpgradeResponse { - extractor: WebSocketUpgrade, - callback: F, -} - -impl IntoResponse for WebSocketUpgradeResponse -where - F: FnOnce(WebSocket) -> Fut + Send + 'static, - Fut: Future + Send + 'static, -{ - fn into_response(self) -> Response { - // check requested protocols - let protocol = self - .extractor - .sec_websocket_protocol - .as_ref() - .and_then(|req_protocols| { - let req_protocols = req_protocols.to_str().ok()?; - let protocols = self.extractor.protocols.as_ref()?; - req_protocols - .split(',') - .map(|req_p| req_p.trim()) - .find(|req_p| protocols.iter().any(|p| p == req_p)) - }); - - let protocol = match protocol { - Some(protocol) => { - if let Ok(protocol) = HeaderValue::from_str(protocol) { - Some(protocol) - } else { - return ( - StatusCode::BAD_REQUEST, - "`Sec-WebSocket-Protocol` header is invalid", - ) - .into_response(); - } - } - None => None, - }; - - let callback = self.callback; - let on_upgrade = self.extractor.on_upgrade; - let config = self.extractor.config; - - tokio::spawn(async move { - let upgraded = on_upgrade.await.expect("connection upgrade failed"); - let socket = - WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) - .await; - let socket = WebSocket { inner: socket }; - callback(socket).await; - }); - - let mut builder = Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header( - header::CONNECTION, - HeaderValue::from_str("upgrade").unwrap(), - ) - .header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap()) - .header( - header::SEC_WEBSOCKET_ACCEPT, - sign(self.extractor.sec_websocket_key.as_bytes()), - ); - - if let Some(protocol) = protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); - } - - builder.body(body::boxed(body::Empty::new())).unwrap() - } -} - /// A stream of WebSocket messages. #[derive(Debug)] pub struct WebSocket {