1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Make websocket handlers support extractors ()

This commit is contained in:
David Pedersen 2021-07-30 15:19:53 +02:00 committed by GitHub
parent d927c819d3
commit d843f4378b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 332 additions and 143 deletions

View file

@ -5,11 +5,12 @@
//! ```
//! RUST_LOG=tower_http=debug,key_value_store=trace \
//! cargo run \
//! --features ws \
//! --all-features \
//! --example websocket
//! ```
use axum::{
extract::TypedHeader,
prelude::*,
routing::nest,
service::ServiceExt,
@ -57,7 +58,13 @@ async fn main() {
.unwrap();
}
async fn handle_socket(mut socket: WebSocket) {
async fn handle_socket(
mut socket: WebSocket,
// websocket handlers can also use extractors
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
) {
println!("`{}` connected", user_agent.as_str());
if let Some(msg) = socket.recv().await {
let msg = msg.unwrap();
println!("Client says: {:?}", msg);

View file

@ -248,7 +248,7 @@ use crate::{response::IntoResponse, util::ByteStr};
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use futures_util::stream::Stream;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use http::{header, Extensions, HeaderMap, Method, Request, Response, Uri, Version};
use rejection::*;
use serde::de::DeserializeOwned;
use std::{
@ -475,6 +475,46 @@ impl<B> RequestParts<B> {
}
}
#[async_trait]
impl<B> FromRequest<B> for ()
where
B: Send,
{
type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> {
Ok(())
}
}
macro_rules! impl_from_request {
() => {
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<B, $head, $($tail,)*> FromRequest<B> for ($head, $($tail,)*)
where
$head: FromRequest<B> + Send,
$( $tail: FromRequest<B> + Send, )*
B: Send,
{
type Rejection = Response<crate::body::Body>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let $head = $head::from_request(req).await.map_err(IntoResponse::into_response)?;
$( let $tail = $tail::from_request(req).await.map_err(IntoResponse::into_response)?; )*
Ok(($head, $($tail,)*))
}
}
impl_from_request!($($tail,)*);
};
}
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
#[async_trait]
impl<T, B> FromRequest<B> for Option<T>
where
@ -1233,3 +1273,39 @@ where
})
}
}
/// Extractor that extracts the raw query string, without parsing it.
///
/// # Example
///
/// ```rust,no_run
/// use axum::prelude::*;
/// use futures::StreamExt;
///
/// async fn handler(extract::RawQuery(query): extract::RawQuery) {
/// // ...
/// }
///
/// let app = route("/users", get(handler));
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[derive(Debug)]
pub struct RawQuery(pub Option<String>);
#[async_trait]
impl<B> FromRequest<B> for RawQuery
where
B: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let query = req
.uri()
.and_then(|uri| uri.query())
.map(|query| query.to_string());
Ok(Self(query))
}
}

View file

@ -172,7 +172,7 @@ where
}
}
mod sealed {
pub(crate) mod sealed {
#![allow(unreachable_pub, missing_docs, missing_debug_implementations)]
pub trait HiddentTrait {}
@ -188,8 +188,8 @@ mod sealed {
/// See the [module docs](crate::handler) for more details.
#[async_trait]
pub trait Handler<B, In>: Sized {
// This seals the trait. We cannot use the regular "sealed super trait" approach
// due to coherence.
// This seals the trait. We cannot use the regular "sealed super trait"
// approach due to coherence.
#[doc(hidden)]
type Sealed: sealed::HiddentTrait;
@ -256,7 +256,8 @@ where
}
macro_rules! impl_handler {
() => {};
() => {
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]

View file

@ -65,7 +65,7 @@
//! ["extractors"](#extractors) as arguments and returns something that
//! can be converted [into a response](#building-responses).
//!
//! Handlers is where you custom domain logic lives and axum applications are
//! Handlers is where your custom domain logic lives and axum applications are
//! built by routing between handlers.
//!
//! Some examples of handlers:
@ -78,14 +78,14 @@
//! // Handler that immediately returns an empty `200 OK` response.
//! async fn unit_handler() {}
//!
//! // Handler that immediately returns an empty `200 Ok` response with a plain
//! // Handler that immediately returns an empty `200 OK` response with a plain
//! // text body.
//! async fn string_handler() -> String {
//! "Hello, World!".to_string()
//! }
//!
//! // Handler that buffers the request body and returns it if it is valid UTF-8
//! async fn buffer_body(body: Bytes) -> Result<String, StatusCode> {
//! // Handler that buffers the request body and returns it.
//! async fn echo(body: Bytes) -> Result<String, StatusCode> {
//! if let Ok(string) = String::from_utf8(body.to_vec()) {
//! Ok(string)
//! } else {
@ -248,7 +248,7 @@
//! "foo"
//! }
//!
//! // String works too and will get a text/plain content-type
//! // String works too and will get a `text/plain` content-type
//! async fn plain_text_string(uri: Uri) -> String {
//! format!("Hi from {}", uri.path())
//! }
@ -547,14 +547,13 @@
//! [`Timeout`]: tower::timeout::Timeout
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
#![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")]
#![doc(html_root_url = "https://docs.rs/axum/0.1.0")]
#![warn(
clippy::all,
clippy::dbg_macro,
clippy::todo,
clippy::empty_enum,
clippy::enum_glob_use,
clippy::pub_enum_variant_names,
clippy::mem_forget,
clippy::unused_self,
clippy::filter_map_next,

View file

@ -1,68 +1,10 @@
//! Future types.
use bytes::Bytes;
use http::{HeaderValue, Response, StatusCode};
use http_body::Full;
use sha1::{Digest, Sha1};
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use crate::body::BoxBody;
use http::Response;
use std::convert::Infallible;
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
#[derive(Debug)]
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
impl ResponseFuture {
pub(super) fn ok(key: HeaderValue) -> Self {
Self(Ok(Some(key)))
}
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
Self(Err(Some((status, body))))
}
}
impl Future for ResponseFuture {
type Output = Result<Response<Full<Bytes>>, Infallible>;
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match self.get_mut().0.as_mut() {
Ok(key) => Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(
http::header::SEC_WEBSOCKET_ACCEPT,
sign(key.take().unwrap().as_bytes()),
)
.body(Full::new(Bytes::new()))
.unwrap(),
Err(err) => {
let (status, body) = err.take().unwrap();
Response::builder()
.status(status)
.body(Full::from(body))
.unwrap()
}
};
Poll::Ready(Ok(res))
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
opaque_future! {
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub type ResponseFuture = futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
}

View file

@ -17,11 +17,48 @@
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! Websocket handlers can also use extractors, however the first function
//! argument must be of type [`WebSocket`]:
//!
//! ```
//! use axum::{prelude::*, extract::{RequestParts, FromRequest}, ws::{ws, WebSocket}};
//! use http::{HeaderMap, StatusCode};
//!
//! /// An extractor that authorizes requests.
//! struct RequireAuth;
//!
//! #[async_trait::async_trait]
//! impl<B> FromRequest<B> for RequireAuth
//! where
//! B: Send,
//! {
//! type Rejection = StatusCode;
//!
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
//! # unimplemented!()
//! // Put your auth logic here...
//! }
//! }
//!
//! let app = route("/ws", ws(handle_socket));
//!
//! async fn handle_socket(
//! mut socket: WebSocket,
//! // Run `RequireAuth` for each request before upgrading.
//! _auth: RequireAuth,
//! ) {
//! // ...
//! }
//! # async {
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
use crate::{
routing::EmptyRouter,
service::{BoxResponseBody, OnMethod},
};
use crate::body::{box_body, BoxBody};
use crate::extract::{FromRequest, RequestParts};
use crate::response::IntoResponse;
use async_trait::async_trait;
use bytes::Bytes;
use future::ResponseFuture;
use futures_util::{sink::SinkExt, stream::StreamExt};
@ -31,7 +68,11 @@ use http::{
};
use http_body::Full;
use hyper::upgrade::{OnUpgrade, Upgraded};
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
task::Poll,
};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
@ -44,31 +85,103 @@ pub mod future;
/// each connection.
///
/// See the [module docs](crate::ws) for more details.
pub fn ws<F, Fut, B>(callback: F) -> OnMethod<BoxResponseBody<WebSocketUpgrade<F>, B>, EmptyRouter>
pub fn ws<F, B, T>(callback: F) -> WebSocketUpgrade<F, B, T>
where
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: WebSocketHandler<B, T>,
{
let svc = WebSocketUpgrade {
WebSocketUpgrade {
callback,
config: WebSocketConfig::default(),
};
crate::service::get::<_, B>(svc)
_request_body: PhantomData,
}
}
/// Trait for async functions that can be used to handle websocket requests.
///
/// You shouldn't need to depend on this trait directly. It is automatically
/// implemented to closures of the right types.
///
/// See the [module docs](crate::ws) for more details.
#[async_trait]
pub trait WebSocketHandler<B, In>: Sized {
// This seals the trait. We cannot use the regular "sealed super trait"
// approach due to coherence.
#[doc(hidden)]
type Sealed: crate::handler::sealed::HiddentTrait;
/// Call the handler with the given websocket stream and input parsed by
/// extractors.
async fn call(self, stream: WebSocket, input: In);
}
#[async_trait]
impl<F, Fut, B> WebSocketHandler<B, ()> for F
where
F: FnOnce(WebSocket) -> Fut + Send,
Fut: Future<Output = ()> + Send,
B: Send,
{
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, stream: WebSocket, _: ()) {
self(stream).await
}
}
macro_rules! impl_ws_handler {
() => {
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, B, $head, $($tail,)*> WebSocketHandler<B, ($head, $($tail,)*)> for F
where
B: Send,
$head: FromRequest<B> + Send + 'static,
$( $tail: FromRequest<B> + Send + 'static, )*
F: FnOnce(WebSocket, $head, $($tail,)*) -> Fut + Send,
Fut: Future<Output = ()> + Send,
{
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, stream: WebSocket, ($head, $($tail,)*): ($head, $($tail,)*)) {
self(stream, $head, $($tail,)*).await
}
}
impl_ws_handler!($($tail,)*);
};
}
impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
/// [`Service`] that upgrades connections to websockets and spawns a task to
/// handle the stream.
///
/// Created with [`ws`].
///
/// See the [module docs](crate::ws) for more details.
#[derive(Clone)]
pub struct WebSocketUpgrade<F> {
pub struct WebSocketUpgrade<F, B, T> {
callback: F,
config: WebSocketConfig,
_request_body: PhantomData<fn() -> (B, T)>,
}
impl<F> fmt::Debug for WebSocketUpgrade<F> {
impl<F, B, T> Clone for WebSocketUpgrade<F, B, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
callback: self.callback.clone(),
config: self.config,
_request_body: PhantomData,
}
}
}
impl<F, B, T> fmt::Debug for WebSocketUpgrade<F, B, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
@ -77,7 +190,7 @@ impl<F> fmt::Debug for WebSocketUpgrade<F> {
}
}
impl<F> WebSocketUpgrade<F> {
impl<F, B, T> WebSocketUpgrade<F, B, T> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
@ -97,12 +210,13 @@ impl<F> WebSocketUpgrade<F> {
}
}
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
where
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: WebSocketHandler<ReqBody, T> + Clone + Send + 'static,
T: FromRequest<ReqBody> + Send + 'static,
ReqBody: Send + 'static,
{
type Response = Response<Full<Bytes>>;
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = ResponseFuture;
@ -111,62 +225,104 @@ where
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
if !header_eq(
&req,
header::CONNECTION,
HeaderValue::from_static("upgrade"),
) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
let this = self.clone();
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
ResponseFuture(Box::pin(async move {
if req.method() != http::Method::GET {
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
}
if !header_eq(
&req,
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
if !header_eq(
&req,
header::CONNECTION,
HeaderValue::from_static("upgrade"),
) {
return response(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
return response(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
if !header_eq(
&req,
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
) {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
let config = self.config;
let callback = self.callback.clone();
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
let socket = WebSocket { inner: socket };
callback(socket).await;
});
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
ResponseFuture::ok(key)
let config = this.config;
let callback = this.callback.clone();
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(rejection) => {
let res = rejection.into_response().map(box_body);
return Ok(res);
}
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket = WebSocketStream::from_raw_socket(
upgraded,
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket { inner: socket };
callback.call(socket, input).await;
});
let res = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()))
.body(box_body(Full::new(Bytes::new())))
.unwrap();
Ok(res)
}))
}
}
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
let res = Response::builder()
.status(status)
.body(box_body(Full::from(body)))
.unwrap();
Ok(res)
}
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
@ -175,6 +331,14 @@ fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
/// A stream of websocket messages.
#[derive(Debug)]
pub struct WebSocket {