mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
Make websocket handlers support extractors (#41)
This commit is contained in:
parent
d927c819d3
commit
d843f4378b
6 changed files with 332 additions and 143 deletions
|
@ -5,11 +5,12 @@
|
|||
//! ```
|
||||
//! RUST_LOG=tower_http=debug,key_value_store=trace \
|
||||
//! cargo run \
|
||||
//! --features ws \
|
||||
//! --all-features \
|
||||
//! --example websocket
|
||||
//! ```
|
||||
|
||||
use axum::{
|
||||
extract::TypedHeader,
|
||||
prelude::*,
|
||||
routing::nest,
|
||||
service::ServiceExt,
|
||||
|
@ -57,7 +58,13 @@ async fn main() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket) {
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
// websocket handlers can also use extractors
|
||||
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
|
||||
) {
|
||||
println!("`{}` connected", user_agent.as_str());
|
||||
|
||||
if let Some(msg) = socket.recv().await {
|
||||
let msg = msg.unwrap();
|
||||
println!("Client says: {:?}", msg);
|
||||
|
|
|
@ -248,7 +248,7 @@ use crate::{response::IntoResponse, util::ByteStr};
|
|||
use async_trait::async_trait;
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures_util::stream::Stream;
|
||||
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
||||
use http::{header, Extensions, HeaderMap, Method, Request, Response, Uri, Version};
|
||||
use rejection::*;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
|
@ -475,6 +475,46 @@ impl<B> RequestParts<B> {
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for ()
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_from_request {
|
||||
() => {
|
||||
};
|
||||
|
||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<B, $head, $($tail,)*> FromRequest<B> for ($head, $($tail,)*)
|
||||
where
|
||||
$head: FromRequest<B> + Send,
|
||||
$( $tail: FromRequest<B> + Send, )*
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = Response<crate::body::Body>;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let $head = $head::from_request(req).await.map_err(IntoResponse::into_response)?;
|
||||
$( let $tail = $tail::from_request(req).await.map_err(IntoResponse::into_response)?; )*
|
||||
Ok(($head, $($tail,)*))
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_request!($($tail,)*);
|
||||
};
|
||||
}
|
||||
|
||||
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||
|
||||
#[async_trait]
|
||||
impl<T, B> FromRequest<B> for Option<T>
|
||||
where
|
||||
|
@ -1233,3 +1273,39 @@ where
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Extractor that extracts the raw query string, without parsing it.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use axum::prelude::*;
|
||||
/// use futures::StreamExt;
|
||||
///
|
||||
/// async fn handler(extract::RawQuery(query): extract::RawQuery) {
|
||||
/// // ...
|
||||
/// }
|
||||
///
|
||||
/// let app = route("/users", get(handler));
|
||||
/// # async {
|
||||
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct RawQuery(pub Option<String>);
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for RawQuery
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let query = req
|
||||
.uri()
|
||||
.and_then(|uri| uri.query())
|
||||
.map(|query| query.to_string());
|
||||
Ok(Self(query))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,7 +172,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
pub(crate) mod sealed {
|
||||
#![allow(unreachable_pub, missing_docs, missing_debug_implementations)]
|
||||
|
||||
pub trait HiddentTrait {}
|
||||
|
@ -188,8 +188,8 @@ mod sealed {
|
|||
/// See the [module docs](crate::handler) for more details.
|
||||
#[async_trait]
|
||||
pub trait Handler<B, In>: Sized {
|
||||
// This seals the trait. We cannot use the regular "sealed super trait" approach
|
||||
// due to coherence.
|
||||
// This seals the trait. We cannot use the regular "sealed super trait"
|
||||
// approach due to coherence.
|
||||
#[doc(hidden)]
|
||||
type Sealed: sealed::HiddentTrait;
|
||||
|
||||
|
@ -256,7 +256,8 @@ where
|
|||
}
|
||||
|
||||
macro_rules! impl_handler {
|
||||
() => {};
|
||||
() => {
|
||||
};
|
||||
|
||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
|
|
13
src/lib.rs
13
src/lib.rs
|
@ -65,7 +65,7 @@
|
|||
//! ["extractors"](#extractors) as arguments and returns something that
|
||||
//! can be converted [into a response](#building-responses).
|
||||
//!
|
||||
//! Handlers is where you custom domain logic lives and axum applications are
|
||||
//! Handlers is where your custom domain logic lives and axum applications are
|
||||
//! built by routing between handlers.
|
||||
//!
|
||||
//! Some examples of handlers:
|
||||
|
@ -78,14 +78,14 @@
|
|||
//! // Handler that immediately returns an empty `200 OK` response.
|
||||
//! async fn unit_handler() {}
|
||||
//!
|
||||
//! // Handler that immediately returns an empty `200 Ok` response with a plain
|
||||
//! // Handler that immediately returns an empty `200 OK` response with a plain
|
||||
//! // text body.
|
||||
//! async fn string_handler() -> String {
|
||||
//! "Hello, World!".to_string()
|
||||
//! }
|
||||
//!
|
||||
//! // Handler that buffers the request body and returns it if it is valid UTF-8
|
||||
//! async fn buffer_body(body: Bytes) -> Result<String, StatusCode> {
|
||||
//! // Handler that buffers the request body and returns it.
|
||||
//! async fn echo(body: Bytes) -> Result<String, StatusCode> {
|
||||
//! if let Ok(string) = String::from_utf8(body.to_vec()) {
|
||||
//! Ok(string)
|
||||
//! } else {
|
||||
|
@ -248,7 +248,7 @@
|
|||
//! "foo"
|
||||
//! }
|
||||
//!
|
||||
//! // String works too and will get a text/plain content-type
|
||||
//! // String works too and will get a `text/plain` content-type
|
||||
//! async fn plain_text_string(uri: Uri) -> String {
|
||||
//! format!("Hi from {}", uri.path())
|
||||
//! }
|
||||
|
@ -547,14 +547,13 @@
|
|||
//! [`Timeout`]: tower::timeout::Timeout
|
||||
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
|
||||
|
||||
#![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")]
|
||||
#![doc(html_root_url = "https://docs.rs/axum/0.1.0")]
|
||||
#.
|
||||
#[derive(Debug)]
|
||||
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
|
||||
|
||||
impl ResponseFuture {
|
||||
pub(super) fn ok(key: HeaderValue) -> Self {
|
||||
Self(Ok(Some(key)))
|
||||
}
|
||||
|
||||
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
|
||||
Self(Err(Some((status, body))))
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ResponseFuture {
|
||||
type Output = Result<Response<Full<Bytes>>, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let res = match self.get_mut().0.as_mut() {
|
||||
Ok(key) => Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(key.take().unwrap().as_bytes()),
|
||||
)
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap(),
|
||||
Err(err) => {
|
||||
let (status, body) = err.take().unwrap();
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.body(Full::from(body))
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
opaque_future! {
|
||||
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub type ResponseFuture = futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
|
||||
}
|
||||
|
|
294
src/ws/mod.rs
294
src/ws/mod.rs
|
@ -17,11 +17,48 @@
|
|||
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! Websocket handlers can also use extractors, however the first function
|
||||
//! argument must be of type [`WebSocket`]:
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{prelude::*, extract::{RequestParts, FromRequest}, ws::{ws, WebSocket}};
|
||||
//! use http::{HeaderMap, StatusCode};
|
||||
//!
|
||||
//! /// An extractor that authorizes requests.
|
||||
//! struct RequireAuth;
|
||||
//!
|
||||
//! #[async_trait::async_trait]
|
||||
//! impl<B> FromRequest<B> for RequireAuth
|
||||
//! where
|
||||
//! B: Send,
|
||||
//! {
|
||||
//! type Rejection = StatusCode;
|
||||
//!
|
||||
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
//! # unimplemented!()
|
||||
//! // Put your auth logic here...
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! let app = route("/ws", ws(handle_socket));
|
||||
//!
|
||||
//! async fn handle_socket(
|
||||
//! mut socket: WebSocket,
|
||||
//! // Run `RequireAuth` for each request before upgrading.
|
||||
//! _auth: RequireAuth,
|
||||
//! ) {
|
||||
//! // ...
|
||||
//! }
|
||||
//! # async {
|
||||
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
|
||||
use crate::{
|
||||
routing::EmptyRouter,
|
||||
service::{BoxResponseBody, OnMethod},
|
||||
};
|
||||
use crate::body::{box_body, BoxBody};
|
||||
use crate::extract::{FromRequest, RequestParts};
|
||||
use crate::response::IntoResponse;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use future::ResponseFuture;
|
||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||
|
@ -31,7 +68,11 @@ use http::{
|
|||
};
|
||||
use http_body::Full;
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
|
||||
task::Poll,
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::protocol::{self, WebSocketConfig},
|
||||
WebSocketStream,
|
||||
|
@ -44,31 +85,103 @@ pub mod future;
|
|||
/// each connection.
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
pub fn ws<F, Fut, B>(callback: F) -> OnMethod<BoxResponseBody<WebSocketUpgrade<F>, B>, EmptyRouter>
|
||||
pub fn ws<F, B, T>(callback: F) -> WebSocketUpgrade<F, B, T>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
F: WebSocketHandler<B, T>,
|
||||
{
|
||||
let svc = WebSocketUpgrade {
|
||||
WebSocketUpgrade {
|
||||
callback,
|
||||
config: WebSocketConfig::default(),
|
||||
};
|
||||
crate::service::get::<_, B>(svc)
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for async functions that can be used to handle websocket requests.
|
||||
///
|
||||
/// You shouldn't need to depend on this trait directly. It is automatically
|
||||
/// implemented to closures of the right types.
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
#[async_trait]
|
||||
pub trait WebSocketHandler<B, In>: Sized {
|
||||
// This seals the trait. We cannot use the regular "sealed super trait"
|
||||
// approach due to coherence.
|
||||
#[doc(hidden)]
|
||||
type Sealed: crate::handler::sealed::HiddentTrait;
|
||||
|
||||
/// Call the handler with the given websocket stream and input parsed by
|
||||
/// extractors.
|
||||
async fn call(self, stream: WebSocket, input: In);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, Fut, B> WebSocketHandler<B, ()> for F
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send,
|
||||
Fut: Future<Output = ()> + Send,
|
||||
B: Send,
|
||||
{
|
||||
type Sealed = crate::handler::sealed::Hidden;
|
||||
|
||||
async fn call(self, stream: WebSocket, _: ()) {
|
||||
self(stream).await
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_ws_handler {
|
||||
() => {
|
||||
};
|
||||
|
||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<F, Fut, B, $head, $($tail,)*> WebSocketHandler<B, ($head, $($tail,)*)> for F
|
||||
where
|
||||
B: Send,
|
||||
$head: FromRequest<B> + Send + 'static,
|
||||
$( $tail: FromRequest<B> + Send + 'static, )*
|
||||
F: FnOnce(WebSocket, $head, $($tail,)*) -> Fut + Send,
|
||||
Fut: Future<Output = ()> + Send,
|
||||
{
|
||||
type Sealed = crate::handler::sealed::Hidden;
|
||||
|
||||
async fn call(self, stream: WebSocket, ($head, $($tail,)*): ($head, $($tail,)*)) {
|
||||
self(stream, $head, $($tail,)*).await
|
||||
}
|
||||
}
|
||||
|
||||
impl_ws_handler!($($tail,)*);
|
||||
};
|
||||
}
|
||||
|
||||
impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||
|
||||
/// [`Service`] that upgrades connections to websockets and spawns a task to
|
||||
/// handle the stream.
|
||||
///
|
||||
/// Created with [`ws`].
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
#[derive(Clone)]
|
||||
pub struct WebSocketUpgrade<F> {
|
||||
pub struct WebSocketUpgrade<F, B, T> {
|
||||
callback: F,
|
||||
config: WebSocketConfig,
|
||||
_request_body: PhantomData<fn() -> (B, T)>,
|
||||
}
|
||||
|
||||
impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
||||
impl<F, B, T> Clone for WebSocketUpgrade<F, B, T>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
callback: self.callback.clone(),
|
||||
config: self.config,
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, B, T> fmt::Debug for WebSocketUpgrade<F, B, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
||||
|
@ -77,7 +190,7 @@ impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<F> WebSocketUpgrade<F> {
|
||||
impl<F, B, T> WebSocketUpgrade<F, B, T> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
|
@ -97,12 +210,13 @@ impl<F> WebSocketUpgrade<F> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
|
||||
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
F: WebSocketHandler<ReqBody, T> + Clone + Send + 'static,
|
||||
T: FromRequest<ReqBody> + Send + 'static,
|
||||
ReqBody: Send + 'static,
|
||||
{
|
||||
type Response = Response<Full<Bytes>>;
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = Infallible;
|
||||
type Future = ResponseFuture;
|
||||
|
||||
|
@ -111,62 +225,104 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_static("upgrade"),
|
||||
) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
let this = self.clone();
|
||||
|
||||
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
ResponseFuture(Box::pin(async move {
|
||||
if req.method() != http::Method::GET {
|
||||
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||
}
|
||||
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
HeaderValue::from_static("13"),
|
||||
) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_static("upgrade"),
|
||||
) {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
HeaderValue::from_static("13"),
|
||||
) {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
|
||||
let config = self.config;
|
||||
let callback = self.callback.clone();
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback(socket).await;
|
||||
});
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
|
||||
ResponseFuture::ok(key)
|
||||
let config = this.config;
|
||||
let callback = this.callback.clone();
|
||||
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(rejection) => {
|
||||
let res = rejection.into_response().map(box_body);
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
upgraded,
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback.call(socket, input).await;
|
||||
});
|
||||
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()))
|
||||
.body(box_body(Full::new(Bytes::new())))
|
||||
.unwrap();
|
||||
|
||||
Ok(res)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
|
||||
let res = Response::builder()
|
||||
.status(status)
|
||||
.body(box_body(Full::from(body)))
|
||||
.unwrap();
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
||||
if let Some(header) = req.headers().get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
|
@ -175,6 +331,14 @@ fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
||||
|
||||
/// A stream of websocket messages.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
|
|
Loading…
Add table
Reference in a new issue