mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 00:50:32 +01:00
Simplify Websocket implementation (#615)
* Remove `WebSocketUpgradeResponse` * Move protocol selection to `WebSocketUpgrade::protocols`
This commit is contained in:
parent
6feea82d61
commit
980a0a466e
1 changed files with 56 additions and 89 deletions
|
@ -65,11 +65,7 @@
|
|||
|
||||
use self::rejection::*;
|
||||
use super::{rejection::*, FromRequest, RequestParts};
|
||||
use crate::{
|
||||
body,
|
||||
response::{IntoResponse, Response},
|
||||
Error,
|
||||
};
|
||||
use crate::{body, response::Response, Error};
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{
|
||||
|
@ -106,7 +102,8 @@ use tokio_tungstenite::{
|
|||
#[derive(Debug)]
|
||||
pub struct WebSocketUpgrade {
|
||||
config: WebSocketConfig,
|
||||
protocols: Option<Box<[Cow<'static, str>]>>,
|
||||
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
|
||||
protocol: Option<HeaderValue>,
|
||||
sec_websocket_key: HeaderValue,
|
||||
on_upgrade: OnUpgrade,
|
||||
sec_websocket_protocol: Option<HeaderValue>,
|
||||
|
@ -137,6 +134,10 @@ impl WebSocketUpgrade {
|
|||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
|
||||
/// return the protocol name.
|
||||
///
|
||||
/// The protocols should be listed in decreasing order of preference: if the client offers
|
||||
/// multiple protocols that the server could support, the server will pick the first one in
|
||||
/// this list.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
|
@ -164,13 +165,27 @@ impl WebSocketUpgrade {
|
|||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
self.protocols = Some(
|
||||
protocols
|
||||
if let Some(req_protocols) = self
|
||||
.sec_websocket_protocol
|
||||
.as_ref()
|
||||
.and_then(|p| p.to_str().ok())
|
||||
{
|
||||
self.protocol = protocols
|
||||
.into_iter()
|
||||
// FIXME: This will often allocate a new `String` and so is less efficient than it
|
||||
// could be. But that can't be fixed without breaking changes to the public API.
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
);
|
||||
.find(|protocol| {
|
||||
req_protocols
|
||||
.split(',')
|
||||
.any(|req_protocol| req_protocol.trim() == protocol)
|
||||
})
|
||||
.map(|protocol| match protocol {
|
||||
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
|
||||
Cow::Borrowed(s) => HeaderValue::from_static(s),
|
||||
});
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -180,15 +195,40 @@ impl WebSocketUpgrade {
|
|||
/// When using `WebSocketUpgrade`, the response produced by this method
|
||||
/// should be returned from the handler. See the [module docs](self) for an
|
||||
/// example.
|
||||
pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
|
||||
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future + Send + 'static,
|
||||
{
|
||||
WebSocketUpgradeResponse {
|
||||
extractor: self,
|
||||
callback,
|
||||
let on_upgrade = self.on_upgrade;
|
||||
let config = self.config;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(self.sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = self.protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -240,7 +280,7 @@ where
|
|||
|
||||
Ok(Self {
|
||||
config: Default::default(),
|
||||
protocols: None,
|
||||
protocol: None,
|
||||
sec_websocket_key,
|
||||
on_upgrade,
|
||||
sec_websocket_protocol,
|
||||
|
@ -286,79 +326,6 @@ fn header_contains<B>(
|
|||
}
|
||||
}
|
||||
|
||||
struct WebSocketUpgradeResponse<F> {
|
||||
extractor: WebSocketUpgrade,
|
||||
callback: F,
|
||||
}
|
||||
|
||||
impl<F, Fut> IntoResponse for WebSocketUpgradeResponse<F>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future + Send + 'static,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
// check requested protocols
|
||||
let protocol = self
|
||||
.extractor
|
||||
.sec_websocket_protocol
|
||||
.as_ref()
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
let protocols = self.extractor.protocols.as_ref()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-WebSocket-Protocol` header is invalid",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let callback = self.callback;
|
||||
let on_upgrade = self.extractor.on_upgrade;
|
||||
let config = self.extractor.config;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(self.extractor.sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of WebSocket messages.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
|
|
Loading…
Reference in a new issue