Simplify Websocket implementation (#615)

* Remove `WebSocketUpgradeResponse`

* Move protocol selection to `WebSocketUpgrade::protocols`
This commit is contained in:
Kai Jewson 2021-12-12 16:05:28 +00:00 committed by GitHub
parent 6feea82d61
commit 980a0a466e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -65,11 +65,7 @@
use self::rejection::*;
use super::{rejection::*, FromRequest, RequestParts};
use crate::{
body,
response::{IntoResponse, Response},
Error,
};
use crate::{body, response::Response, Error};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{
@ -106,7 +102,8 @@ use tokio_tungstenite::{
#[derive(Debug)]
pub struct WebSocketUpgrade {
config: WebSocketConfig,
protocols: Option<Box<[Cow<'static, str>]>>,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
sec_websocket_protocol: Option<HeaderValue>,
@ -137,6 +134,10 @@ impl WebSocketUpgrade {
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
/// return the protocol name.
///
/// The protocols should be listed in decreasing order of preference: if the client offers
/// multiple protocols that the server could support, the server will pick the first one in
/// this list.
///
/// # Examples
///
/// ```
@ -164,13 +165,27 @@ impl WebSocketUpgrade {
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
self.protocols = Some(
protocols
if let Some(req_protocols) = self
.sec_websocket_protocol
.as_ref()
.and_then(|p| p.to_str().ok())
{
self.protocol = protocols
.into_iter()
// FIXME: This will often allocate a new `String` and so is less efficient than it
// could be. But that can't be fixed without breaking changes to the public API.
.map(Into::into)
.collect::<Vec<_>>()
.into(),
);
.find(|protocol| {
req_protocols
.split(',')
.any(|req_protocol| req_protocol.trim() == protocol)
})
.map(|protocol| match protocol {
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
Cow::Borrowed(s) => HeaderValue::from_static(s),
});
}
self
}
@ -180,15 +195,40 @@ impl WebSocketUpgrade {
/// When using `WebSocketUpgrade`, the response produced by this method
/// should be returned from the handler. See the [module docs](self) for an
/// example.
pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future + Send + 'static,
{
WebSocketUpgradeResponse {
extractor: self,
callback,
let on_upgrade = self.on_upgrade;
let config = self.config;
tokio::spawn(async move {
let upgraded = on_upgrade.await.expect("connection upgrade failed");
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
let socket = WebSocket { inner: socket };
callback(socket).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
@ -240,7 +280,7 @@ where
Ok(Self {
config: Default::default(),
protocols: None,
protocol: None,
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
@ -286,79 +326,6 @@ fn header_contains<B>(
}
}
struct WebSocketUpgradeResponse<F> {
extractor: WebSocketUpgrade,
callback: F,
}
impl<F, Fut> IntoResponse for WebSocketUpgradeResponse<F>
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future + Send + 'static,
{
fn into_response(self) -> Response {
// check requested protocols
let protocol = self
.extractor
.sec_websocket_protocol
.as_ref()
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
let protocols = self.extractor.protocols.as_ref()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return (
StatusCode::BAD_REQUEST,
"`Sec-WebSocket-Protocol` header is invalid",
)
.into_response();
}
}
None => None,
};
let callback = self.callback;
let on_upgrade = self.extractor.on_upgrade;
let config = self.extractor.config;
tokio::spawn(async move {
let upgraded = on_upgrade.await.expect("connection upgrade failed");
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
let socket = WebSocket { inner: socket };
callback(socket).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.extractor.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(body::boxed(body::Empty::new())).unwrap()
}
}
/// A stream of WebSocket messages.
#[derive(Debug)]
pub struct WebSocket {