mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-03 17:52:18 +01:00
Simplify Websocket implementation (#615)
* Remove `WebSocketUpgradeResponse` * Move protocol selection to `WebSocketUpgrade::protocols`
This commit is contained in:
parent
6feea82d61
commit
980a0a466e
1 changed files with 56 additions and 89 deletions
|
@ -65,11 +65,7 @@
|
||||||
|
|
||||||
use self::rejection::*;
|
use self::rejection::*;
|
||||||
use super::{rejection::*, FromRequest, RequestParts};
|
use super::{rejection::*, FromRequest, RequestParts};
|
||||||
use crate::{
|
use crate::{body, response::Response, Error};
|
||||||
body,
|
|
||||||
response::{IntoResponse, Response},
|
|
||||||
Error,
|
|
||||||
};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::{
|
use futures_util::{
|
||||||
|
@ -106,7 +102,8 @@ use tokio_tungstenite::{
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WebSocketUpgrade {
|
pub struct WebSocketUpgrade {
|
||||||
config: WebSocketConfig,
|
config: WebSocketConfig,
|
||||||
protocols: Option<Box<[Cow<'static, str>]>>,
|
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
|
||||||
|
protocol: Option<HeaderValue>,
|
||||||
sec_websocket_key: HeaderValue,
|
sec_websocket_key: HeaderValue,
|
||||||
on_upgrade: OnUpgrade,
|
on_upgrade: OnUpgrade,
|
||||||
sec_websocket_protocol: Option<HeaderValue>,
|
sec_websocket_protocol: Option<HeaderValue>,
|
||||||
|
@ -137,6 +134,10 @@ impl WebSocketUpgrade {
|
||||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
|
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
|
||||||
/// return the protocol name.
|
/// return the protocol name.
|
||||||
///
|
///
|
||||||
|
/// The protocols should be listed in decreasing order of preference: if the client offers
|
||||||
|
/// multiple protocols that the server could support, the server will pick the first one in
|
||||||
|
/// this list.
|
||||||
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
|
@ -164,13 +165,27 @@ impl WebSocketUpgrade {
|
||||||
I: IntoIterator,
|
I: IntoIterator,
|
||||||
I::Item: Into<Cow<'static, str>>,
|
I::Item: Into<Cow<'static, str>>,
|
||||||
{
|
{
|
||||||
self.protocols = Some(
|
if let Some(req_protocols) = self
|
||||||
protocols
|
.sec_websocket_protocol
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|p| p.to_str().ok())
|
||||||
|
{
|
||||||
|
self.protocol = protocols
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
// FIXME: This will often allocate a new `String` and so is less efficient than it
|
||||||
|
// could be. But that can't be fixed without breaking changes to the public API.
|
||||||
.map(Into::into)
|
.map(Into::into)
|
||||||
.collect::<Vec<_>>()
|
.find(|protocol| {
|
||||||
.into(),
|
req_protocols
|
||||||
);
|
.split(',')
|
||||||
|
.any(|req_protocol| req_protocol.trim() == protocol)
|
||||||
|
})
|
||||||
|
.map(|protocol| match protocol {
|
||||||
|
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
|
||||||
|
Cow::Borrowed(s) => HeaderValue::from_static(s),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,15 +195,40 @@ impl WebSocketUpgrade {
|
||||||
/// When using `WebSocketUpgrade`, the response produced by this method
|
/// When using `WebSocketUpgrade`, the response produced by this method
|
||||||
/// should be returned from the handler. See the [module docs](self) for an
|
/// should be returned from the handler. See the [module docs](self) for an
|
||||||
/// example.
|
/// example.
|
||||||
pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
|
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
|
||||||
where
|
where
|
||||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||||
Fut: Future + Send + 'static,
|
Fut: Future + Send + 'static,
|
||||||
{
|
{
|
||||||
WebSocketUpgradeResponse {
|
let on_upgrade = self.on_upgrade;
|
||||||
extractor: self,
|
let config = self.config;
|
||||||
callback,
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
||||||
|
let socket =
|
||||||
|
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||||
|
.await;
|
||||||
|
let socket = WebSocket { inner: socket };
|
||||||
|
callback(socket).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut builder = Response::builder()
|
||||||
|
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
|
.header(
|
||||||
|
header::CONNECTION,
|
||||||
|
HeaderValue::from_str("upgrade").unwrap(),
|
||||||
|
)
|
||||||
|
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
|
||||||
|
.header(
|
||||||
|
header::SEC_WEBSOCKET_ACCEPT,
|
||||||
|
sign(self.sec_websocket_key.as_bytes()),
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(protocol) = self.protocol {
|
||||||
|
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
builder.body(body::boxed(body::Empty::new())).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,7 +280,7 @@ where
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config: Default::default(),
|
config: Default::default(),
|
||||||
protocols: None,
|
protocol: None,
|
||||||
sec_websocket_key,
|
sec_websocket_key,
|
||||||
on_upgrade,
|
on_upgrade,
|
||||||
sec_websocket_protocol,
|
sec_websocket_protocol,
|
||||||
|
@ -286,79 +326,6 @@ fn header_contains<B>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WebSocketUpgradeResponse<F> {
|
|
||||||
extractor: WebSocketUpgrade,
|
|
||||||
callback: F,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F, Fut> IntoResponse for WebSocketUpgradeResponse<F>
|
|
||||||
where
|
|
||||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
|
||||||
Fut: Future + Send + 'static,
|
|
||||||
{
|
|
||||||
fn into_response(self) -> Response {
|
|
||||||
// check requested protocols
|
|
||||||
let protocol = self
|
|
||||||
.extractor
|
|
||||||
.sec_websocket_protocol
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|req_protocols| {
|
|
||||||
let req_protocols = req_protocols.to_str().ok()?;
|
|
||||||
let protocols = self.extractor.protocols.as_ref()?;
|
|
||||||
req_protocols
|
|
||||||
.split(',')
|
|
||||||
.map(|req_p| req_p.trim())
|
|
||||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
|
||||||
});
|
|
||||||
|
|
||||||
let protocol = match protocol {
|
|
||||||
Some(protocol) => {
|
|
||||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
|
||||||
Some(protocol)
|
|
||||||
} else {
|
|
||||||
return (
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
"`Sec-WebSocket-Protocol` header is invalid",
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let callback = self.callback;
|
|
||||||
let on_upgrade = self.extractor.on_upgrade;
|
|
||||||
let config = self.extractor.config;
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
|
||||||
let socket =
|
|
||||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
|
||||||
.await;
|
|
||||||
let socket = WebSocket { inner: socket };
|
|
||||||
callback(socket).await;
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut builder = Response::builder()
|
|
||||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
|
||||||
.header(
|
|
||||||
header::CONNECTION,
|
|
||||||
HeaderValue::from_str("upgrade").unwrap(),
|
|
||||||
)
|
|
||||||
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
|
|
||||||
.header(
|
|
||||||
header::SEC_WEBSOCKET_ACCEPT,
|
|
||||||
sign(self.extractor.sec_websocket_key.as_bytes()),
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(protocol) = protocol {
|
|
||||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.body(body::boxed(body::Empty::new())).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A stream of WebSocket messages.
|
/// A stream of WebSocket messages.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WebSocket {
|
pub struct WebSocket {
|
||||||
|
|
Loading…
Reference in a new issue