Add support for WebSockets over HTTP/2 (#2894)

This commit is contained in:
Sabrina Jewson 2024-10-06 08:58:34 +01:00 committed by GitHub
parent d783a8b17e
commit 64e6edac05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 373 additions and 85 deletions

View file

@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many` - **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many`
to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645]) to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645])
- **change:** Update minimum rust version to 1.75 ([#2943]) - **change:** Update minimum rust version to 1.75 ([#2943])
- **added:** Add support WebSockets over HTTP/2.
They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`.
[#2473]: https://github.com/tokio-rs/axum/pull/2473 [#2473]: https://github.com/tokio-rs/axum/pull/2473
[#2645]: https://github.com/tokio-rs/axum/pull/2645 [#2645]: https://github.com/tokio-rs/axum/pull/2645

View file

@ -112,6 +112,7 @@ features = [
[dev-dependencies] [dev-dependencies]
anyhow = "1.0" anyhow = "1.0"
axum-macros = { path = "../axum-macros", features = ["__private"] } axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0" quickcheck = "1.0"
quickcheck_macros = "1.0" quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }

View file

@ -5,12 +5,12 @@
//! ``` //! ```
//! use axum::{ //! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket}, //! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get, //! routing::any,
//! response::{IntoResponse, Response}, //! response::{IntoResponse, Response},
//! Router, //! Router,
//! }; //! };
//! //!
//! let app = Router::new().route("/ws", get(handler)); //! let app = Router::new().route("/ws", any(handler));
//! //!
//! async fn handler(ws: WebSocketUpgrade) -> Response { //! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket) //! ws.on_upgrade(handle_socket)
@ -40,7 +40,7 @@
//! use axum::{ //! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, //! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response, //! response::Response,
//! routing::get, //! routing::any,
//! Router, //! Router,
//! }; //! };
//! //!
@ -58,7 +58,7 @@
//! } //! }
//! //!
//! let app = Router::new() //! let app = Router::new()
//! .route("/ws", get(handler)) //! .route("/ws", any(handler))
//! .with_state(AppState { /* ... */ }); //! .with_state(AppState { /* ... */ });
//! # let _: Router = app; //! # let _: Router = app;
//! ``` //! ```
@ -101,7 +101,7 @@ use futures_util::{
use http::{ use http::{
header::{self, HeaderMap, HeaderName, HeaderValue}, header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts, request::Parts,
Method, StatusCode, Method, StatusCode, Version,
}; };
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
@ -121,17 +121,20 @@ use tokio_tungstenite::{
/// Extractor for establishing WebSocket connections. /// Extractor for establishing WebSocket connections.
/// ///
/// Note: This extractor requires the request method to be `GET` so it should /// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// always be used with [`get`](crate::routing::get). Requests with other methods will be /// in later versions, `CONNECT` is used instead.
/// rejected. /// To support both, it should be used with [`any`](crate::routing::any).
/// ///
/// See the [module docs](self) for an example. /// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))] #[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> { pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig, config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>, protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue, /// `None` if HTTP/2+ WebSockets are used.
sec_websocket_key: Option<HeaderValue>,
on_upgrade: hyper::upgrade::OnUpgrade, on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F, on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>, sec_websocket_protocol: Option<HeaderValue>,
@ -212,12 +215,12 @@ impl<F> WebSocketUpgrade<F> {
/// ``` /// ```
/// use axum::{ /// use axum::{
/// extract::ws::{WebSocketUpgrade, WebSocket}, /// extract::ws::{WebSocketUpgrade, WebSocket},
/// routing::get, /// routing::any,
/// response::{IntoResponse, Response}, /// response::{IntoResponse, Response},
/// Router, /// Router,
/// }; /// };
/// ///
/// let app = Router::new().route("/ws", get(handler)); /// let app = Router::new().route("/ws", any(handler));
/// ///
/// async fn handler(ws: WebSocketUpgrade) -> Response { /// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.protocols(["graphql-ws", "graphql-transport-ws"]) /// ws.protocols(["graphql-ws", "graphql-transport-ws"])
@ -329,25 +332,34 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await; callback(socket).await;
}); });
#[allow(clippy::declare_interior_mutable_const)] if let Some(sec_websocket_key) = &self.sec_websocket_key {
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); // If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
let mut builder = Response::builder() #[allow(clippy::declare_interior_mutable_const)]
.status(StatusCode::SWITCHING_PROTOCOLS) const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
.header(header::CONNECTION, UPGRADE) #[allow(clippy::declare_interior_mutable_const)]
.header(header::UPGRADE, WEBSOCKET) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol { let mut builder = Response::builder()
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); .status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);
if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
builder.body(Body::empty()).unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
// with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
} }
builder.body(Body::empty()).unwrap()
} }
} }
@ -387,28 +399,49 @@ where
type Rejection = WebSocketUpgradeRejection; type Rejection = WebSocketUpgradeRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET { let sec_websocket_key = if parts.version <= Version::HTTP_11 {
return Err(MethodNotGet.into()); if parts.method != Method::GET {
} return Err(MethodNotGet.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into()); return Err(InvalidConnectionHeader.into());
} }
if !header_eq(&parts.headers, header::UPGRADE, "websocket") { if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into()); return Err(InvalidUpgradeHeader.into());
} }
Some(
parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone(),
)
} else {
if parts.method != Method::CONNECT {
return Err(MethodNotConnect.into());
}
// if this feature flag is disabled, we wont be receiving an HTTP/2 request to begin
// with.
#[cfg(feature = "http2")]
if parts
.extensions
.get::<hyper::ext::Protocol>()
.map_or(true, |p| p.as_str() != "websocket")
{
return Err(InvalidProtocolPseudoheader.into());
}
None
};
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into()); return Err(InvalidWebSocketVersionHeader.into());
} }
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();
let on_upgrade = parts let on_upgrade = parts
.extensions .extensions
.remove::<hyper::upgrade::OnUpgrade>() .remove::<hyper::upgrade::OnUpgrade>()
@ -706,6 +739,13 @@ pub mod rejection {
pub struct MethodNotGet; pub struct MethodNotGet;
} }
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `CONNECT`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect;
}
define_rejection! { define_rejection! {
#[status = BAD_REQUEST] #[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"] #[body = "Connection header did not include 'upgrade'"]
@ -720,6 +760,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader; pub struct InvalidUpgradeHeader;
} }
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader;
}
define_rejection! { define_rejection! {
#[status = BAD_REQUEST] #[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"] #[body = "`Sec-WebSocket-Version` header did not include '13'"]
@ -755,8 +802,10 @@ pub mod rejection {
/// extractor can fail. /// extractor can fail.
pub enum WebSocketUpgradeRejection { pub enum WebSocketUpgradeRejection {
MethodNotGet, MethodNotGet,
MethodNotConnect,
InvalidConnectionHeader, InvalidConnectionHeader,
InvalidUpgradeHeader, InvalidUpgradeHeader,
InvalidProtocolPseudoheader,
InvalidWebSocketVersionHeader, InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing, WebSocketKeyHeaderMissing,
ConnectionNotUpgradable, ConnectionNotUpgradable,
@ -838,14 +887,18 @@ mod tests {
use std::future::ready; use std::future::ready;
use super::*; use super::*;
use crate::{routing::get, test_helpers::spawn_service, Router}; use crate::{routing::any, test_helpers::spawn_service, Router};
use http::{Request, Version}; use http::{Request, Version};
use http_body_util::BodyExt as _;
use hyper_util::rt::TokioExecutor;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite; use tokio_tungstenite::tungstenite;
use tower::ServiceExt; use tower::ServiceExt;
#[crate::test] #[crate::test]
async fn rejects_http_1_0_requests() { async fn rejects_http_1_0_requests() {
let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| { let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
let rejection = ws.unwrap_err(); let rejection = ws.unwrap_err();
assert!(matches!( assert!(matches!(
rejection, rejection,
@ -874,7 +927,7 @@ mod tests {
async fn handler(ws: WebSocketUpgrade) -> Response { async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {}) ws.on_upgrade(|_| async {})
} }
let _: Router = Router::new().route("/", get(handler)); let _: Router = Router::new().route("/", any(handler));
} }
#[allow(dead_code)] #[allow(dead_code)]
@ -883,16 +936,61 @@ mod tests {
ws.on_failed_upgrade(|_error: Error| println!("oops!")) ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {}) .on_upgrade(|_| async {})
} }
let _: Router = Router::new().route("/", get(handler)); let _: Router = Router::new().route("/", any(handler));
} }
#[crate::test] #[crate::test]
async fn integration_test() { async fn integration_test() {
let app = Router::new().route( let addr = spawn_service(echo_app());
"/echo", let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), .await
); .unwrap();
test_echo_app(socket).await;
}
#[crate::test]
#[cfg(feature = "http2")]
async fn http2() {
let addr = spawn_service(echo_app());
let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (mut send_request, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(io)
.await
.unwrap();
// Wait a little for the SETTINGS frame to go through…
for _ in 0..10 {
tokio::task::yield_now().await;
}
assert!(conn.is_extended_connect_protocol_enabled());
tokio::spawn(async {
conn.await.unwrap();
});
let req = Request::builder()
.method(Method::CONNECT)
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();
let response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {}: {body}", status);
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
}
fn echo_app() -> Router {
async fn handle_socket(mut socket: WebSocket) { async fn handle_socket(mut socket: WebSocket) {
while let Some(Ok(msg)) = socket.recv().await { while let Some(Ok(msg)) = socket.recv().await {
match msg { match msg {
@ -908,11 +1006,13 @@ mod tests {
} }
} }
let addr = spawn_service(app); Router::new().route(
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) "/echo",
.await any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
.unwrap(); )
}
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
let input = tungstenite::Message::Text("foobar".to_owned()); let input = tungstenite::Message::Text("foobar".to_owned());
socket.send(input.clone()).await.unwrap(); socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap(); let output = socket.next().await.unwrap().unwrap();

View file

@ -1035,13 +1035,11 @@ where
match $svc { match $svc {
MethodEndpoint::None => {} MethodEndpoint::None => {}
MethodEndpoint::Route(route) => { MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.clone().oneshot_inner($req)) return route.clone().oneshot_inner($req);
.strip_body($method == Method::HEAD);
} }
MethodEndpoint::BoxedHandler(handler) => { MethodEndpoint::BoxedHandler(handler) => {
let route = handler.clone().into_route(state); let mut route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req)) return route.oneshot_inner($req);
.strip_body($method == Method::HEAD);
} }
} }
} }

View file

@ -670,12 +670,10 @@ where
fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> { fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
match self { match self {
Fallback::Default(route) | Fallback::Service(route) => { Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner(req),
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::BoxedHandler(handler) => { Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state); let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req)) route.oneshot_inner(req)
} }
} }
} }

View file

@ -7,7 +7,7 @@ use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes; use bytes::Bytes;
use http::{ use http::{
header::{self, CONTENT_LENGTH}, header::{self, CONTENT_LENGTH},
HeaderMap, HeaderValue, HeaderMap, HeaderValue, Method,
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
@ -42,11 +42,9 @@ impl<E> Route<E> {
)) ))
} }
pub(crate) fn oneshot_inner( pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
&mut self, let method = req.method().clone();
req: Request, RouteFuture::from_future(method, self.0.clone().oneshot(req))
) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
self.0.clone().oneshot(req)
} }
pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError> pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
@ -98,8 +96,7 @@ where
#[inline] #[inline]
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let req = req.map(Body::new); self.oneshot_inner(req.map(Body::new))
RouteFuture::from_future(self.oneshot_inner(req))
} }
} }
@ -108,7 +105,7 @@ pin_project! {
pub struct RouteFuture<E> { pub struct RouteFuture<E> {
#[pin] #[pin]
kind: RouteFutureKind<E>, kind: RouteFutureKind<E>,
strip_body: bool, method: Method,
allow_header: Option<Bytes>, allow_header: Option<Bytes>,
} }
} }
@ -131,20 +128,16 @@ pin_project! {
impl<E> RouteFuture<E> { impl<E> RouteFuture<E> {
pub(crate) fn from_future( pub(crate) fn from_future(
method: Method,
future: Oneshot<BoxCloneService<Request, Response, E>, Request>, future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
) -> Self { ) -> Self {
Self { Self {
kind: RouteFutureKind::Future { future }, kind: RouteFutureKind::Future { future },
strip_body: false, method,
allow_header: None, allow_header: None,
} }
} }
pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
self.strip_body = strip_body;
self
}
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
self.allow_header = Some(allow_header); self.allow_header = Some(allow_header);
self self
@ -171,10 +164,24 @@ impl<E> Future for RouteFuture<E> {
set_allow_header(res.headers_mut(), this.allow_header); set_allow_header(res.headers_mut(), this.allow_header);
// make sure to set content-length before removing the body if *this.method == Method::CONNECT && res.status().is_success() {
set_content_length(res.size_hint(), res.headers_mut()); // From https://httpwg.org/specs/rfc9110.html#CONNECT:
// > A server MUST NOT send any Transfer-Encoding or
// > Content-Length header fields in a 2xx (Successful)
// > response to CONNECT.
if res.headers().contains_key(&CONTENT_LENGTH)
|| res.headers().contains_key(&header::TRANSFER_ENCODING)
|| res.size_hint().lower() != 0
{
error!("response to CONNECT with nonempty body");
res = res.map(|_| Body::empty());
}
} else {
// make sure to set content-length before removing the body
set_content_length(res.size_hint(), res.headers_mut());
}
let res = if *this.strip_body { let res = if *this.method == Method::HEAD {
res.map(|_| Body::empty()) res.map(|_| Body::empty())
} else { } else {
res res

View file

@ -367,7 +367,11 @@ where
let close_rx = close_rx.clone(); let close_rx = close_rx.clone();
tokio::spawn(async move { tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new()); #[allow(unused_mut)]
let mut builder = Builder::new(TokioExecutor::new());
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
pin_mut!(conn); pin_mut!(conn);

View file

@ -0,0 +1,13 @@
[package]
name = "example-websockets-http2"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws", "http2"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
tokio = { version = "1", features = ["full"] }
tower-http = { version = "0.5.0", features = ["fs"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -0,0 +1,7 @@
<p>Open this page in two windows and try sending some messages!</p>
<form action="javascript:void(0)">
<input type="text" name="content" required>
<button>Send</button>
</form>
<div id="messages"></div>
<script src='script.js'></script>

View file

@ -0,0 +1,11 @@
const socket = new WebSocket('wss://localhost:3000/ws');
socket.addEventListener('message', e => {
document.getElementById("messages").append(e.data, document.createElement("br"));
});
const form = document.querySelector("form");
form.addEventListener("submit", () => {
socket.send(form.elements.namedItem("content").value);
form.elements.namedItem("content").value = "";
});

View file

@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL
BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X
DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR
BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5
IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0
daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4
kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq
dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT
bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6
J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK
NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0
yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W
ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU
XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg
+MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9
Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24
fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr
AopvZ09uEQ==
-----END CERTIFICATE-----

View file

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD
BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS
tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw
RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l
YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t
HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV
W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB
12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI
+2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw
zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt
fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty
RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT
ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1
myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ
XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY
5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD
ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD
ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27
k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7
wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV
5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg
3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa
Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C
rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m
y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW
w37pCS7ykL+7gp7V0WShYsw=
-----END PRIVATE KEY-----

View file

@ -0,0 +1,97 @@
//! Run with
//!
//! ```not_rust
//! cargo run -p example-websockets-http2
//! ```
use axum::{
extract::{
ws::{self, WebSocketUpgrade},
State,
},
http::Version,
routing::any,
Router,
};
use axum_server::tls_rustls::RustlsConfig;
use std::{net::SocketAddr, path::PathBuf};
use tokio::sync::broadcast;
use tower_http::services::ServeDir;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets");
// configure certificate and private key used by https
let config = RustlsConfig::from_pem_file(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
)
.await
.unwrap();
// build our application with some routes and a broadcast channel
let app = Router::new()
.fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
.route("/ws", any(ws_handler))
.with_state(broadcast::channel::<String>(16).0);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
let mut server = axum_server::bind_rustls(addr, config);
// IMPORTANT: This is required to advertise our support for HTTP/2 websockets to the client.
// If you use axum::serve, it is enabled by default.
server.http_builder().http2().enable_connect_protocol();
server.serve(app.into_make_service()).await.unwrap();
}
async fn ws_handler(
ws: WebSocketUpgrade,
version: Version,
State(sender): State<broadcast::Sender<String>>,
) -> axum::response::Response {
tracing::debug!("accepted a WebSocket using {version:?}");
let mut receiver = sender.subscribe();
ws.on_upgrade(|mut ws| async move {
loop {
tokio::select! {
// Since `ws` is a `Stream`, it is by nature cancel-safe.
res = ws.recv() => {
match res {
Some(Ok(ws::Message::Text(s))) => {
let _ = sender.send(s);
}
Some(Ok(_)) => {}
Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"),
None => break,
}
}
// Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe.
res = receiver.recv() => {
match res {
Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg)).await {
tracing::debug!("client disconnected abruptly: {e}");
}
Err(_) => continue,
}
}
}
}
})
}

View file

@ -19,7 +19,7 @@
use axum::{ use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade}, extract::ws::{Message, WebSocket, WebSocketUpgrade},
response::IntoResponse, response::IntoResponse,
routing::get, routing::any,
Router, Router,
}; };
use axum_extra::TypedHeader; use axum_extra::TypedHeader;
@ -57,7 +57,7 @@ async fn main() {
// build our application with some routes // build our application with some routes
let app = Router::new() let app = Router::new()
.fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true))
.route("/ws", get(ws_handler)) .route("/ws", any(ws_handler))
// logging so we can see whats going on // logging so we can see whats going on
.layer( .layer(
TraceLayer::new_for_http() TraceLayer::new_for_http()
@ -77,7 +77,7 @@ async fn main() {
.unwrap(); .unwrap();
} }
/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start /// The handler for the HTTP request (this gets called when the HTTP request lands at the start
/// of websocket negotiation). After this completes, the actual switching from HTTP to /// of websocket negotiation). After this completes, the actual switching from HTTP to
/// websocket protocol will occur. /// websocket protocol will occur.
/// This is the last point where we can extract TCP/IP metadata such as IP address of the client /// This is the last point where we can extract TCP/IP metadata such as IP address of the client