Change WebSocket API to use an extractor (#121)

Fixes https://github.com/tokio-rs/axum/issues/111

Example usage:

```rust
use axum::{
    prelude::*,
    extract::ws::{WebSocketUpgrade, WebSocket},
    response::IntoResponse,
};

let app = route("/ws", get(handler));

async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
    ws.on_upgrade(handle_socket)
}

async fn handle_socket(mut socket: WebSocket) {
    while let Some(msg) = socket.recv().await {
        let msg = if let Ok(msg) = msg {
            msg
        } else {
            // client disconnected
            return;
        };

        if socket.send(msg).await.is_err() {
            // client disconnected
            return;
        }
    }
}
```
This commit is contained in:
David Pedersen 2021-08-07 17:26:23 +02:00 committed by GitHub
parent 404a3b5e8a
commit 4194cf70da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 619 additions and 652 deletions

View file

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Breaking changes
- Change WebSocket API to use an extractor ([#121](https://github.com/tokio-rs/axum/pull/121))
- Add `RoutingDsl::or` for combining routes. ([#108](https://github.com/tokio-rs/axum/pull/108))
- Ensure a `HandleError` service created from `axum::ServiceExt::handle_error`
_does not_ implement `RoutingDsl` as that could lead to confusing routing

View file

@ -14,9 +14,9 @@ use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::broadcast;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::prelude::*;
use axum::response::Html;
use axum::ws::{ws, Message, WebSocket};
use axum::response::{Html, IntoResponse};
use axum::AddExtensionLayer;
// Our shared state
@ -33,7 +33,7 @@ async fn main() {
let app_state = Arc::new(AppState { user_set, tx });
let app = route("/", get(index))
.route("/websocket", ws(websocket))
.route("/websocket", get(websocket_handler))
.layer(AddExtensionLayer::new(app_state));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -44,10 +44,14 @@ async fn main() {
.unwrap();
}
async fn websocket(
stream: WebSocket,
async fn websocket_handler(
ws: WebSocketUpgrade,
extract::Extension(state): extract::Extension<Arc<AppState>>,
) {
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
// By splitting we can send and receive at the same time.
let (mut sender, mut receiver) = stream.split();

View file

@ -7,11 +7,14 @@
//! ```
use axum::{
extract::TypedHeader,
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
TypedHeader,
},
prelude::*,
response::IntoResponse,
routing::nest,
service::ServiceExt,
ws::{ws, Message, WebSocket},
};
use http::StatusCode;
use std::net::SocketAddr;
@ -44,7 +47,7 @@ async fn main() {
)
// routes are matched from bottom to top, so we have to put `nest` at the
// top since it matches all routes
.route("/ws", ws(handle_socket))
.route("/ws", get(ws_handler))
// logging so we can see whats going on
.layer(
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)),
@ -59,15 +62,18 @@ async fn main() {
.unwrap();
}
async fn handle_socket(
mut socket: WebSocket,
// websocket handlers can also use extractors
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
) {
) -> impl IntoResponse {
if let Some(TypedHeader(user_agent)) = user_agent {
println!("`{}` connected", user_agent.as_str());
}
ws.on_upgrade(handle_socket)
}
async fn handle_socket(mut socket: WebSocket) {
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
println!("Client says: {:?}", msg);

View file

@ -254,6 +254,10 @@ pub mod connect_info;
pub mod extractor_middleware;
pub mod rejection;
#[cfg(feature = "ws")]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub mod ws;
mod content_length_limit;
mod extension;
mod form;
@ -292,6 +296,11 @@ pub mod multipart;
#[doc(inline)]
pub use self::multipart::Multipart;
#[cfg(feature = "ws")]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
#[doc(inline)]
pub use self::ws::WebSocketUpgrade;
#[cfg(feature = "headers")]
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
mod typed_header;

586
src/extract/ws.rs Normal file
View file

@ -0,0 +1,586 @@
//! Handle WebSocket connections.
//!
//! # Example
//!
//! ```
//! use axum::{
//! prelude::*,
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! response::IntoResponse,
//! };
//!
//! let app = route("/ws", get(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
//! ws.on_upgrade(handle_socket)
//! }
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = if let Ok(msg) = msg {
//! msg
//! } else {
//! // client disconnected
//! return;
//! };
//!
//! if socket.send(msg).await.is_err() {
//! // client disconnected
//! return;
//! }
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
use self::rejection::*;
use super::{rejection::*, FromRequest, RequestParts};
use crate::response::IntoResponse;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderName, HeaderValue},
Method, Response, StatusCode,
};
use hyper::{
upgrade::{OnUpgrade, Upgraded},
Body,
};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};
use tower::BoxError;
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::handler::get). Requests with other methods will be
/// rejected.
///
/// See the [module docs](self) for an example.
#[derive(Debug)]
pub struct WebSocketUpgrade {
config: WebSocketConfig,
protocols: Option<Box<[Cow<'static, str>]>>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
sec_websocket_protocol: Option<HeaderValue>,
}
impl WebSocketUpgrade {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
///
/// # Examples
///
/// ```
/// use axum::{
/// prelude::*,
/// extract::ws::{WebSocketUpgrade, WebSocket},
/// response::IntoResponse,
/// };
///
/// let app = route("/ws", get(handler));
///
/// async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
/// .on_upgrade(|socket| async {
/// // ...
/// })
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
self.protocols = Some(
protocols
.into_iter()
.map(Into::into)
.collect::<Vec<_>>()
.into(),
);
self
}
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
///
/// When using `WebSocketUpgrade`, the response produced by this method
/// should be returned from the handler. See the [module docs](self) for an
/// example.
pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future + Send + 'static,
{
WebSocketUpgradeResponse {
extractor: self,
callback,
}
}
}
#[async_trait]
impl<B> FromRequest<B> for WebSocketUpgrade
where
B: Send,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if req.method().ok_or(MethodAlreadyExtracted)? != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(req, header::CONNECTION, "upgrade")? {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(req, header::UPGRADE, "websocket")? {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13")? {
return Err(InvalidWebsocketVersionHeader.into());
}
let sec_websocket_key = if let Some(key) = req
.headers_mut()
.ok_or(HeadersAlreadyExtracted)?
.remove(header::SEC_WEBSOCKET_KEY)
{
key
} else {
return Err(WebsocketKeyHeaderMissing.into());
};
let on_upgrade = req
.extensions_mut()
.ok_or(ExtensionsAlreadyExtracted)?
.remove::<OnUpgrade>()
.unwrap();
let sec_websocket_protocol = req
.headers()
.ok_or(HeadersAlreadyExtracted)?
.get(header::SEC_WEBSOCKET_PROTOCOL)
.cloned();
Ok(Self {
config: Default::default(),
protocols: None,
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
})
}
}
fn header_eq<B>(
req: &RequestParts<B>,
key: HeaderName,
value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) {
Ok(header.as_bytes().eq_ignore_ascii_case(value.as_bytes()))
} else {
Ok(false)
}
}
fn header_contains<B>(
req: &RequestParts<B>,
key: HeaderName,
value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
let header = if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) {
header
} else {
return Ok(false);
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
Ok(header.to_ascii_lowercase().contains(value))
} else {
Ok(false)
}
}
struct WebSocketUpgradeResponse<F> {
extractor: WebSocketUpgrade,
callback: F,
}
impl<F, Fut> IntoResponse for WebSocketUpgradeResponse<F>
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future + Send + 'static,
{
fn into_response(self) -> Response<Body> {
// check requested protocols
let protocol = self
.extractor
.sec_websocket_protocol
.as_ref()
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
let protocols = self.extractor.protocols.as_ref()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return (
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Protocol` header is invalid",
)
.into_response();
}
}
None => None,
};
let callback = self.callback;
let on_upgrade = self.extractor.on_upgrade;
let config = self.extractor.config;
tokio::spawn(async move {
let upgraded = on_upgrade.await.expect("connection upgrade failed");
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
let socket = WebSocket { inner: socket };
callback(socket).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.extractor.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(Body::empty()).unwrap()
}
}
/// A stream of WebSocket messages.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<Upgraded>,
}
impl WebSocket {
/// Receive another message.
///
/// Returns `None` if the stream stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
self.next().await
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
self.inner.send(msg.inner).await.map_err(Into::into)
}
/// Gracefully close this WebSocket.
pub async fn close(mut self) -> Result<(), BoxError> {
self.inner.close(None).await.map_err(Into::into)
}
}
impl Stream for WebSocket {
type Item = Result<Message, BoxError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx).map(|option_msg| {
option_msg.map(|result_msg| {
result_msg
.map_err(Into::into)
.map(|inner| Message { inner })
})
})
}
}
impl Sink<Message> for WebSocket {
type Error = BoxError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.inner)
.map_err(Into::into)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
}
}
/// A WebSocket message.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
inner: protocol::Message,
}
impl Message {
/// Construct a new Text `Message`.
pub fn text<S>(s: S) -> Message
where
S: Into<String>,
{
Message {
inner: protocol::Message::text(s),
}
}
/// Construct a new Binary `Message`.
pub fn binary<V>(v: V) -> Message
where
V: Into<Vec<u8>>,
{
Message {
inner: protocol::Message::binary(v),
}
}
/// Construct a new Ping `Message`.
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Ping(v.into()),
}
}
/// Construct a new Pong `Message`.
///
/// Note that one rarely needs to manually construct a Pong message because
/// the underlying tungstenite socket automatically responds to the Ping
/// messages it receives. Manual construction might still be useful in some
/// cases like in tests or to send unidirectional heartbeats.
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Pong(v.into()),
}
}
/// Construct the default Close `Message`.
pub fn close() -> Message {
Message {
inner: protocol::Message::Close(None),
}
}
/// Construct a Close `Message` with a code and reason.
pub fn close_with<C, R>(code: C, reason: R) -> Message
where
C: Into<u16>,
R: Into<Cow<'static, str>>,
{
Message {
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
code: protocol::frame::coding::CloseCode::from(code.into()),
reason: reason.into(),
})),
}
}
/// Returns true if this message is a Text message.
pub fn is_text(&self) -> bool {
self.inner.is_text()
}
/// Returns true if this message is a Binary message.
pub fn is_binary(&self) -> bool {
self.inner.is_binary()
}
/// Returns true if this message a is a Close message.
pub fn is_close(&self) -> bool {
self.inner.is_close()
}
/// Returns true if this message is a Ping message.
pub fn is_ping(&self) -> bool {
self.inner.is_ping()
}
/// Returns true if this message is a Pong message.
pub fn is_pong(&self) -> bool {
self.inner.is_pong()
}
/// Try to get the close frame (close code and reason)
pub fn close_frame(&self) -> Option<(u16, &str)> {
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
Some((close_frame.code.into(), close_frame.reason.as_ref()))
} else {
None
}
}
/// Try to get a reference to the string text, if this is a Text message.
pub fn to_str(&self) -> Option<&str> {
if let protocol::Message::Text(s) = &self.inner {
Some(s)
} else {
None
}
}
/// Return the bytes of this message, if the message can contain data.
pub fn as_bytes(&self) -> &[u8] {
match self.inner {
protocol::Message::Text(ref s) => s.as_bytes(),
protocol::Message::Binary(ref v) => v,
protocol::Message::Ping(ref v) => v,
protocol::Message::Pong(ref v) => v,
protocol::Message::Close(_) => &[],
}
}
/// Destructure this message into binary data.
pub fn into_bytes(self) -> Vec<u8> {
self.inner.into_data()
}
}
impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.inner.fmt(f)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_bytes()
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
pub mod rejection {
//! WebSocket specific rejections.
use crate::extract::rejection::*;
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `GET`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotGet;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidConnectionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Upgrade` header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidUpgradeHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-Websocket-Version` header did not include '13'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidWebsocketVersionHeader;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-Websocket-Key` header missing"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct WebsocketKeyHeaderMissing;
}
composite_rejection! {
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
///
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidWebsocketVersionHeader,
WebsocketKeyHeaderMissing,
MethodAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
}

View file

@ -727,10 +727,6 @@ pub mod routing;
pub mod service;
pub mod sse;
#[cfg(feature = "ws")]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub mod ws;
#[cfg(test)]
mod tests;

View file

@ -1,10 +0,0 @@
//! Future types.
use crate::body::BoxBody;
use http::Response;
use std::convert::Infallible;
opaque_future! {
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub type ResponseFuture = futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
}

View file

@ -1,625 +0,0 @@
//! Handle websocket connections.
//!
//! # Example
//!
//! ```
//! use axum::{prelude::*, ws::{ws, WebSocket}};
//!
//! let app = route("/ws", ws(handle_socket));
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = msg.unwrap();
//! socket.send(msg).await.unwrap();
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! Websocket handlers can also use extractors, however the first function
//! argument must be of type [`WebSocket`]:
//!
//! ```
//! use axum::{prelude::*, extract::{RequestParts, FromRequest}, ws::{ws, WebSocket}};
//! use http::{HeaderMap, StatusCode};
//!
//! /// An extractor that authorizes requests.
//! struct RequireAuth;
//!
//! #[async_trait::async_trait]
//! impl<B> FromRequest<B> for RequireAuth
//! where
//! B: Send,
//! {
//! type Rejection = StatusCode;
//!
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
//! # unimplemented!()
//! // Put your auth logic here...
//! }
//! }
//!
//! let app = route("/ws", ws(handle_socket));
//!
//! async fn handle_socket(
//! mut socket: WebSocket,
//! // Run `RequireAuth` for each request before upgrading.
//! _auth: RequireAuth,
//! ) {
//! // ...
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
use crate::body::{box_body, BoxBody};
use crate::extract::{FromRequest, RequestParts};
use crate::response::IntoResponse;
use async_trait::async_trait;
use bytes::Bytes;
use future::ResponseFuture;
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderName},
HeaderValue, Request, Response, StatusCode,
};
use http_body::Full;
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
use std::pin::Pin;
use std::sync::Arc;
use std::{
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
task::Poll,
};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};
use tower::{BoxError, Service};
pub mod future;
/// Create a new [`WebSocketUpgrade`] service that will call the closure with
/// each connection.
///
/// See the [module docs](crate::ws) for more details.
pub fn ws<F, B, T>(callback: F) -> WebSocketUpgrade<F, B, T>
where
F: WebSocketHandler<B, T>,
{
WebSocketUpgrade {
callback,
config: WebSocketConfig::default(),
protocols: Vec::new().into(),
_request_body: PhantomData,
}
}
/// Trait for async functions that can be used to handle websocket requests.
///
/// You shouldn't need to depend on this trait directly. It is automatically
/// implemented to closures of the right types.
///
/// See the [module docs](crate::ws) for more details.
#[async_trait]
pub trait WebSocketHandler<B, In>: Sized {
// This seals the trait. We cannot use the regular "sealed super trait"
// approach due to coherence.
#[doc(hidden)]
type Sealed: crate::handler::sealed::HiddentTrait;
/// Call the handler with the given websocket stream and input parsed by
/// extractors.
async fn call(self, stream: WebSocket, input: In);
}
#[async_trait]
impl<F, Fut, B> WebSocketHandler<B, ()> for F
where
F: FnOnce(WebSocket) -> Fut + Send,
Fut: Future<Output = ()> + Send,
B: Send,
{
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, stream: WebSocket, _: ()) {
self(stream).await
}
}
macro_rules! impl_ws_handler {
() => {
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, B, $head, $($tail,)*> WebSocketHandler<B, ($head, $($tail,)*)> for F
where
B: Send,
$head: FromRequest<B> + Send + 'static,
$( $tail: FromRequest<B> + Send + 'static, )*
F: FnOnce(WebSocket, $head, $($tail,)*) -> Fut + Send,
Fut: Future<Output = ()> + Send,
{
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, stream: WebSocket, ($head, $($tail,)*): ($head, $($tail,)*)) {
self(stream, $head, $($tail,)*).await
}
}
impl_ws_handler!($($tail,)*);
};
}
impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
/// [`Service`] that upgrades connections to websockets and spawns a task to
/// handle the stream.
///
/// Created with [`ws`].
///
/// See the [module docs](crate::ws) for more details.
pub struct WebSocketUpgrade<F, B, T> {
callback: F,
config: WebSocketConfig,
protocols: Arc<[Cow<'static, str>]>,
_request_body: PhantomData<fn() -> (B, T)>,
}
impl<F, B, T> Clone for WebSocketUpgrade<F, B, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
callback: self.callback.clone(),
config: self.config,
protocols: self.protocols.clone(),
_request_body: PhantomData,
}
}
}
impl<F, B, T> fmt::Debug for WebSocketUpgrade<F, B, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
.field("config", &self.config)
.finish()
}
}
impl<F, B, T> WebSocketUpgrade<F, B, T> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
/// Set the known protocols.
///
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
///
/// # Examples
///
/// ```no_run
/// # use axum::prelude::*;
/// # use axum::ws::{ws, WebSocket};
/// # use std::net::SocketAddr;
/// #
/// # async fn handle_socket(socket: WebSocket) {
/// # todo!()
/// # }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// let app = route("/ws", ws(handle_socket).protocols(["graphql-ws", "graphql-transport-ws"]));
/// #
/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
/// # axum::Server::bind(&addr)
/// # .serve(app.into_make_service())
/// # .await
/// # .unwrap();
/// # }
///
/// ```
pub fn protocols<I>(mut self, protocols: I) -> Self
where
I: IntoIterator,
I::Item: Into<Cow<'static, str>>,
{
self.protocols = protocols
.into_iter()
.map(Into::into)
.collect::<Vec<_>>()
.into();
self
}
}
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
where
F: WebSocketHandler<ReqBody, T> + Clone + Send + 'static,
T: FromRequest<ReqBody> + Send + 'static,
ReqBody: Send + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let this = self.clone();
let protocols = self.protocols.clone();
ResponseFuture {
future: Box::pin(async move {
if req.method() != http::Method::GET {
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
}
if !header_contains(&req, header::CONNECTION, "upgrade") {
return response(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
if !header_eq(&req, header::UPGRADE, "websocket") {
return response(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
// check requested protocols
let protocol =
req.headers()
.get(&header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Protocol` header is invalid",
);
}
}
None => None,
};
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
let config = this.config;
let callback = this.callback.clone();
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(rejection) => {
let res = rejection.into_response().map(box_body);
return Ok(res);
}
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket = WebSocketStream::from_raw_socket(
upgraded,
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket { inner: socket };
callback.call(socket, input).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
if let Some(protocol) = protocol {
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
Ok(res)
}),
}
}
}
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
let res = Response::builder()
.status(status)
.body(box_body(Full::from(body)))
.unwrap();
Ok(res)
}
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
fn header_contains<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) {
header
} else {
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}
/// A stream of websocket messages.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<Upgraded>,
}
impl WebSocket {
/// Receive another message.
///
/// Returns `None` if the stream stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
self.next().await
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
self.inner.send(msg.inner).await.map_err(Into::into)
}
/// Gracefully close this websocket.
pub async fn close(mut self) -> Result<(), BoxError> {
self.inner.close(None).await.map_err(Into::into)
}
}
impl Stream for WebSocket {
type Item = Result<Message, BoxError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx).map(|option_msg| {
option_msg.map(|result_msg| {
result_msg
.map_err(Into::into)
.map(|inner| Message { inner })
})
})
}
}
impl Sink<Message> for WebSocket {
type Error = BoxError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.inner)
.map_err(Into::into)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
}
}
/// A WebSocket message.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
inner: protocol::Message,
}
impl Message {
/// Construct a new Text `Message`.
pub fn text<S>(s: S) -> Message
where
S: Into<String>,
{
Message {
inner: protocol::Message::text(s),
}
}
/// Construct a new Binary `Message`.
pub fn binary<V>(v: V) -> Message
where
V: Into<Vec<u8>>,
{
Message {
inner: protocol::Message::binary(v),
}
}
/// Construct a new Ping `Message`.
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Ping(v.into()),
}
}
/// Construct a new Pong `Message`.
///
/// Note that one rarely needs to manually construct a Pong message because
/// the underlying tungstenite socket automatically responds to the Ping
/// messages it receives. Manual construction might still be useful in some
/// cases like in tests or to send unidirectional heartbeats.
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Pong(v.into()),
}
}
/// Construct the default Close `Message`.
pub fn close() -> Message {
Message {
inner: protocol::Message::Close(None),
}
}
/// Construct a Close `Message` with a code and reason.
pub fn close_with<C, R>(code: C, reason: R) -> Message
where
C: Into<u16>,
R: Into<Cow<'static, str>>,
{
Message {
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
code: protocol::frame::coding::CloseCode::from(code.into()),
reason: reason.into(),
})),
}
}
/// Returns true if this message is a Text message.
pub fn is_text(&self) -> bool {
self.inner.is_text()
}
/// Returns true if this message is a Binary message.
pub fn is_binary(&self) -> bool {
self.inner.is_binary()
}
/// Returns true if this message a is a Close message.
pub fn is_close(&self) -> bool {
self.inner.is_close()
}
/// Returns true if this message is a Ping message.
pub fn is_ping(&self) -> bool {
self.inner.is_ping()
}
/// Returns true if this message is a Pong message.
pub fn is_pong(&self) -> bool {
self.inner.is_pong()
}
/// Try to get the close frame (close code and reason)
pub fn close_frame(&self) -> Option<(u16, &str)> {
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
Some((close_frame.code.into(), close_frame.reason.as_ref()))
} else {
None
}
}
/// Try to get a reference to the string text, if this is a Text message.
pub fn to_str(&self) -> Option<&str> {
if let protocol::Message::Text(s) = &self.inner {
Some(s)
} else {
None
}
}
/// Return the bytes of this message, if the message can contain data.
pub fn as_bytes(&self) -> &[u8] {
match self.inner {
protocol::Message::Text(ref s) => s.as_bytes(),
protocol::Message::Binary(ref v) => v,
protocol::Message::Ping(ref v) => v,
protocol::Message::Pong(ref v) => v,
protocol::Message::Close(_) => &[],
}
}
/// Destructure this message into binary data.
pub fn into_bytes(self) -> Vec<u8> {
self.inner.into_data()
}
}
impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.inner.fmt(f)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_bytes()
}
}