diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1f20a58f..c1c22845 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `WebSocket::protocol` to return the selected WebSocket subprotocol, if there is one. ([#1022]) + +[#1022]: https://github.com/tokio-rs/axum/pull/1022 # 0.5.5 (10. May, 2022) diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index eef1f681..008115a5 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -207,12 +207,17 @@ impl WebSocketUpgrade { let on_upgrade = self.on_upgrade; let config = self.config; + let protocol = self.protocol.clone(); + tokio::spawn(async move { let upgraded = on_upgrade.await.expect("connection upgrade failed"); let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) .await; - let socket = WebSocket { inner: socket }; + let socket = WebSocket { + inner: socket, + protocol, + }; callback(socket).await; }); @@ -309,6 +314,7 @@ fn header_contains(req: &RequestParts, key: HeaderName, value: &'static st #[derive(Debug)] pub struct WebSocket { inner: WebSocketStream, + protocol: Option, } impl WebSocket { @@ -331,6 +337,11 @@ impl WebSocket { pub async fn close(mut self) -> Result<(), Error> { self.inner.close(None).await.map_err(Error::new) } + + /// Return the selected WebSocket subprotocol, if one has been chosen. + pub fn protocol(&self) -> Option<&HeaderValue> { + self.protocol.as_ref() + } } impl Stream for WebSocket {