diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 365b04a4..36078115 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many` to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645]) - **change:** Update minimum rust version to 1.75 ([#2943]) +- **added:** Add support WebSockets over HTTP/2. + They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`. [#2473]: https://github.com/tokio-rs/axum/pull/2473 [#2645]: https://github.com/tokio-rs/axum/pull/2645 diff --git a/axum/Cargo.toml b/axum/Cargo.toml index c4f7513b..e9e6c646 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -112,6 +112,7 @@ features = [ [dev-dependencies] anyhow = "1.0" axum-macros = { path = "../axum-macros", features = ["__private"] } +hyper = { version = "1.1.0", features = ["client"] } quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 5a18d190..ba686b6e 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -5,12 +5,12 @@ //! ``` //! use axum::{ //! extract::ws::{WebSocketUpgrade, WebSocket}, -//! routing::get, +//! routing::any, //! response::{IntoResponse, Response}, //! Router, //! }; //! -//! let app = Router::new().route("/ws", get(handler)); +//! let app = Router::new().route("/ws", any(handler)); //! //! async fn handler(ws: WebSocketUpgrade) -> Response { //! ws.on_upgrade(handle_socket) @@ -40,7 +40,7 @@ //! use axum::{ //! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, //! response::Response, -//! routing::get, +//! routing::any, //! Router, //! }; //! @@ -58,7 +58,7 @@ //! } //! //! let app = Router::new() -//! .route("/ws", get(handler)) +//! .route("/ws", any(handler)) //! .with_state(AppState { /* ... */ }); //! # let _: Router = app; //! ``` @@ -101,7 +101,7 @@ use futures_util::{ use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, request::Parts, - Method, StatusCode, + Method, StatusCode, Version, }; use hyper_util::rt::TokioIo; use sha1::{Digest, Sha1}; @@ -121,17 +121,20 @@ use tokio_tungstenite::{ /// Extractor for establishing WebSocket connections. /// -/// Note: This extractor requires the request method to be `GET` so it should -/// always be used with [`get`](crate::routing::get). Requests with other methods will be -/// rejected. +/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`; +/// in later versions, `CONNECT` is used instead. +/// To support both, it should be used with [`any`](crate::routing::any). /// /// See the [module docs](self) for an example. +/// +/// [`MethodFilter`]: crate::routing::MethodFilter #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] pub struct WebSocketUpgrade { config: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. protocol: Option, - sec_websocket_key: HeaderValue, + /// `None` if HTTP/2+ WebSockets are used. + sec_websocket_key: Option, on_upgrade: hyper::upgrade::OnUpgrade, on_failed_upgrade: F, sec_websocket_protocol: Option, @@ -212,12 +215,12 @@ impl WebSocketUpgrade { /// ``` /// use axum::{ /// extract::ws::{WebSocketUpgrade, WebSocket}, - /// routing::get, + /// routing::any, /// response::{IntoResponse, Response}, /// Router, /// }; /// - /// let app = Router::new().route("/ws", get(handler)); + /// let app = Router::new().route("/ws", any(handler)); /// /// async fn handler(ws: WebSocketUpgrade) -> Response { /// ws.protocols(["graphql-ws", "graphql-transport-ws"]) @@ -329,25 +332,34 @@ impl WebSocketUpgrade { callback(socket).await; }); - #[allow(clippy::declare_interior_mutable_const)] - const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); - #[allow(clippy::declare_interior_mutable_const)] - const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + if let Some(sec_websocket_key) = &self.sec_websocket_key { + // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. - let mut builder = Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, UPGRADE) - .header(header::UPGRADE, WEBSOCKET) - .header( - header::SEC_WEBSOCKET_ACCEPT, - sign(self.sec_websocket_key.as_bytes()), - ); + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - if let Some(protocol) = self.protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + builder.body(Body::empty()).unwrap() + } else { + // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond + // with a 2XX with an empty body: + // . + Response::new(Body::empty()) } - - builder.body(Body::empty()).unwrap() } } @@ -387,28 +399,49 @@ where type Rejection = WebSocketUpgradeRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - if parts.method != Method::GET { - return Err(MethodNotGet.into()); - } + let sec_websocket_key = if parts.version <= Version::HTTP_11 { + if parts.method != Method::GET { + return Err(MethodNotGet.into()); + } - if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { - return Err(InvalidConnectionHeader.into()); - } + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { + return Err(InvalidConnectionHeader.into()); + } - if !header_eq(&parts.headers, header::UPGRADE, "websocket") { - return Err(InvalidUpgradeHeader.into()); - } + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { + return Err(InvalidUpgradeHeader.into()); + } + + Some( + parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketKeyHeaderMissing)? + .clone(), + ) + } else { + if parts.method != Method::CONNECT { + return Err(MethodNotConnect.into()); + } + + // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin + // with. + #[cfg(feature = "http2")] + if parts + .extensions + .get::() + .map_or(true, |p| p.as_str() != "websocket") + { + return Err(InvalidProtocolPseudoheader.into()); + } + + None + }; if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } - let sec_websocket_key = parts - .headers - .get(header::SEC_WEBSOCKET_KEY) - .ok_or(WebSocketKeyHeaderMissing)? - .clone(); - let on_upgrade = parts .extensions .remove::() @@ -706,6 +739,13 @@ pub mod rejection { pub struct MethodNotGet; } + define_rejection! { + #[status = METHOD_NOT_ALLOWED] + #[body = "Request method must be `CONNECT`"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct MethodNotConnect; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "Connection header did not include 'upgrade'"] @@ -720,6 +760,13 @@ pub mod rejection { pub struct InvalidUpgradeHeader; } + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`:protocol` pseudo-header did not include 'websocket'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidProtocolPseudoheader; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "`Sec-WebSocket-Version` header did not include '13'"] @@ -755,8 +802,10 @@ pub mod rejection { /// extractor can fail. pub enum WebSocketUpgradeRejection { MethodNotGet, + MethodNotConnect, InvalidConnectionHeader, InvalidUpgradeHeader, + InvalidProtocolPseudoheader, InvalidWebSocketVersionHeader, WebSocketKeyHeaderMissing, ConnectionNotUpgradable, @@ -838,14 +887,18 @@ mod tests { use std::future::ready; use super::*; - use crate::{routing::get, test_helpers::spawn_service, Router}; + use crate::{routing::any, test_helpers::spawn_service, Router}; use http::{Request, Version}; + use http_body_util::BodyExt as _; + use hyper_util::rt::TokioExecutor; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; use tokio_tungstenite::tungstenite; use tower::ServiceExt; #[crate::test] async fn rejects_http_1_0_requests() { - let svc = get(|ws: Result| { + let svc = any(|ws: Result| { let rejection = ws.unwrap_err(); assert!(matches!( rejection, @@ -874,7 +927,7 @@ mod tests { async fn handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); } #[allow(dead_code)] @@ -883,16 +936,61 @@ mod tests { ws.on_failed_upgrade(|_error: Error| println!("oops!")) .on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); } #[crate::test] async fn integration_test() { - let app = Router::new().route( - "/echo", - get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), - ); + let addr = spawn_service(echo_app()); + let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) + .await + .unwrap(); + test_echo_app(socket).await; + } + #[crate::test] + #[cfg(feature = "http2")] + async fn http2() { + let addr = spawn_service(echo_app()); + let io = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (mut send_request, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(io) + .await + .unwrap(); + + // Wait a little for the SETTINGS frame to go through… + for _ in 0..10 { + tokio::task::yield_now().await; + } + assert!(conn.is_extended_connect_protocol_enabled()); + tokio::spawn(async { + conn.await.unwrap(); + }); + + let req = Request::builder() + .method(Method::CONNECT) + .extension(hyper::ext::Protocol::from_static("websocket")) + .uri("/echo") + .header("sec-websocket-version", "13") + .header("Host", "server.example.com") + .body(Body::empty()) + .unwrap(); + + let response = send_request.send_request(req).await.unwrap(); + let status = response.status(); + if status != 200 { + let body = response.into_body().collect().await.unwrap().to_bytes(); + let body = std::str::from_utf8(&body).unwrap(); + panic!("response status was {}: {body}", status); + } + let upgraded = hyper::upgrade::on(response).await.unwrap(); + let upgraded = TokioIo::new(upgraded); + let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await; + test_echo_app(socket).await; + } + + fn echo_app() -> Router { async fn handle_socket(mut socket: WebSocket) { while let Some(Ok(msg)) = socket.recv().await { match msg { @@ -908,11 +1006,13 @@ mod tests { } } - let addr = spawn_service(app); - let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) - .await - .unwrap(); + Router::new().route( + "/echo", + any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), + ) + } + async fn test_echo_app(mut socket: WebSocketStream) { let input = tungstenite::Message::Text("foobar".to_owned()); socket.send(input.clone()).await.unwrap(); let output = socket.next().await.unwrap().unwrap(); diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 8c50eccf..3b62f728 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1035,13 +1035,11 @@ where match $svc { MethodEndpoint::None => {} MethodEndpoint::Route(route) => { - return RouteFuture::from_future(route.clone().oneshot_inner($req)) - .strip_body($method == Method::HEAD); + return route.clone().oneshot_inner($req); } MethodEndpoint::BoxedHandler(handler) => { - let route = handler.clone().into_route(state); - return RouteFuture::from_future(route.clone().oneshot_inner($req)) - .strip_body($method == Method::HEAD); + let mut route = handler.clone().into_route(state); + return route.oneshot_inner($req); } } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index be1aeab0..9987dd4f 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -670,12 +670,10 @@ where fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { match self { - Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) - } + Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner(req), Fallback::BoxedHandler(handler) => { let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) + route.oneshot_inner(req) } } } diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index da1e848a..0b88fbcc 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -7,7 +7,7 @@ use axum_core::{extract::Request, response::IntoResponse}; use bytes::Bytes; use http::{ header::{self, CONTENT_LENGTH}, - HeaderMap, HeaderValue, + HeaderMap, HeaderValue, Method, }; use pin_project_lite::pin_project; use std::{ @@ -42,11 +42,9 @@ impl Route { )) } - pub(crate) fn oneshot_inner( - &mut self, - req: Request, - ) -> Oneshot, Request> { - self.0.clone().oneshot(req) + pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture { + let method = req.method().clone(); + RouteFuture::from_future(method, self.0.clone().oneshot(req)) } pub(crate) fn layer(self, layer: L) -> Route @@ -98,8 +96,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - let req = req.map(Body::new); - RouteFuture::from_future(self.oneshot_inner(req)) + self.oneshot_inner(req.map(Body::new)) } } @@ -108,7 +105,7 @@ pin_project! { pub struct RouteFuture { #[pin] kind: RouteFutureKind, - strip_body: bool, + method: Method, allow_header: Option, } } @@ -131,20 +128,16 @@ pin_project! { impl RouteFuture { pub(crate) fn from_future( + method: Method, future: Oneshot, Request>, ) -> Self { Self { kind: RouteFutureKind::Future { future }, - strip_body: false, + method, allow_header: None, } } - pub(crate) fn strip_body(mut self, strip_body: bool) -> Self { - self.strip_body = strip_body; - self - } - pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { self.allow_header = Some(allow_header); self @@ -171,10 +164,24 @@ impl Future for RouteFuture { set_allow_header(res.headers_mut(), this.allow_header); - // make sure to set content-length before removing the body - set_content_length(res.size_hint(), res.headers_mut()); + if *this.method == Method::CONNECT && res.status().is_success() { + // From https://httpwg.org/specs/rfc9110.html#CONNECT: + // > A server MUST NOT send any Transfer-Encoding or + // > Content-Length header fields in a 2xx (Successful) + // > response to CONNECT. + if res.headers().contains_key(&CONTENT_LENGTH) + || res.headers().contains_key(&header::TRANSFER_ENCODING) + || res.size_hint().lower() != 0 + { + error!("response to CONNECT with nonempty body"); + res = res.map(|_| Body::empty()); + } + } else { + // make sure to set content-length before removing the body + set_content_length(res.size_hint(), res.headers_mut()); + } - let res = if *this.strip_body { + let res = if *this.method == Method::HEAD { res.map(|_| Body::empty()) } else { res diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 1ba9a145..27e4912a 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -367,7 +367,11 @@ where let close_rx = close_rx.clone(); tokio::spawn(async move { - let builder = Builder::new(TokioExecutor::new()); + #[allow(unused_mut)] + let mut builder = Builder::new(TokioExecutor::new()); + // CONNECT protocol needed for HTTP/2 websockets + #[cfg(feature = "http2")] + builder.http2().enable_connect_protocol(); let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); pin_mut!(conn); diff --git a/examples/websockets-http2/Cargo.toml b/examples/websockets-http2/Cargo.toml new file mode 100644 index 00000000..19a8d0d7 --- /dev/null +++ b/examples/websockets-http2/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-websockets-http2" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum", features = ["ws", "http2"] } +axum-server = { version = "0.6", features = ["tls-rustls"] } +tokio = { version = "1", features = ["full"] } +tower-http = { version = "0.5.0", features = ["fs"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/websockets-http2/assets/index.html b/examples/websockets-http2/assets/index.html new file mode 100644 index 00000000..6a373782 --- /dev/null +++ b/examples/websockets-http2/assets/index.html @@ -0,0 +1,7 @@ +

Open this page in two windows and try sending some messages!

+
+ + +
+
+ diff --git a/examples/websockets-http2/assets/script.js b/examples/websockets-http2/assets/script.js new file mode 100644 index 00000000..952c21da --- /dev/null +++ b/examples/websockets-http2/assets/script.js @@ -0,0 +1,11 @@ +const socket = new WebSocket('wss://localhost:3000/ws'); + +socket.addEventListener('message', e => { + document.getElementById("messages").append(e.data, document.createElement("br")); +}); + +const form = document.querySelector("form"); +form.addEventListener("submit", () => { + socket.send(form.elements.namedItem("content").value); + form.elements.namedItem("content").value = ""; +}); diff --git a/examples/websockets-http2/self_signed_certs/cert.pem b/examples/websockets-http2/self_signed_certs/cert.pem new file mode 100644 index 00000000..656aa880 --- /dev/null +++ b/examples/websockets-http2/self_signed_certs/cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 +daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 +kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq +dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT +bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 +J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK +NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 +yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W +ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU +XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg ++MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 +Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 +fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr +AopvZ09uEQ== +-----END CERTIFICATE----- diff --git a/examples/websockets-http2/self_signed_certs/key.pem b/examples/websockets-http2/self_signed_certs/key.pem new file mode 100644 index 00000000..3de14eb3 --- /dev/null +++ b/examples/websockets-http2/self_signed_certs/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD +BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS +tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw +RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l +YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t +HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV +W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB +12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI ++2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw +zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt +fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty +RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT +ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 +myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ +XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY +5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD +ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD +ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 +k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 +wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV +5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg +3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa +Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C +rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m +y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW +w37pCS7ykL+7gp7V0WShYsw= +-----END PRIVATE KEY----- diff --git a/examples/websockets-http2/src/main.rs b/examples/websockets-http2/src/main.rs new file mode 100644 index 00000000..dbc682c4 --- /dev/null +++ b/examples/websockets-http2/src/main.rs @@ -0,0 +1,97 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-websockets-http2 +//! ``` + +use axum::{ + extract::{ + ws::{self, WebSocketUpgrade}, + State, + }, + http::Version, + routing::any, + Router, +}; +use axum_server::tls_rustls::RustlsConfig; +use std::{net::SocketAddr, path::PathBuf}; +use tokio::sync::broadcast; +use tower_http::services::ServeDir; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); + + // configure certificate and private key used by https + let config = RustlsConfig::from_pem_file( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + ) + .await + .unwrap(); + + // build our application with some routes and a broadcast channel + let app = Router::new() + .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) + .route("/ws", any(ws_handler)) + .with_state(broadcast::channel::(16).0); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + + let mut server = axum_server::bind_rustls(addr, config); + + // IMPORTANT: This is required to advertise our support for HTTP/2 websockets to the client. + // If you use axum::serve, it is enabled by default. + server.http_builder().http2().enable_connect_protocol(); + + server.serve(app.into_make_service()).await.unwrap(); +} + +async fn ws_handler( + ws: WebSocketUpgrade, + version: Version, + State(sender): State>, +) -> axum::response::Response { + tracing::debug!("accepted a WebSocket using {version:?}"); + let mut receiver = sender.subscribe(); + ws.on_upgrade(|mut ws| async move { + loop { + tokio::select! { + // Since `ws` is a `Stream`, it is by nature cancel-safe. + res = ws.recv() => { + match res { + Some(Ok(ws::Message::Text(s))) => { + let _ = sender.send(s); + } + Some(Ok(_)) => {} + Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"), + None => break, + } + } + // Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe. + res = receiver.recv() => { + match res { + Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg)).await { + tracing::debug!("client disconnected abruptly: {e}"); + } + Err(_) => continue, + } + } + } + } + }) +} diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 7b964404..7c4a9801 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -19,7 +19,7 @@ use axum::{ extract::ws::{Message, WebSocket, WebSocketUpgrade}, response::IntoResponse, - routing::get, + routing::any, Router, }; use axum_extra::TypedHeader; @@ -57,7 +57,7 @@ async fn main() { // build our application with some routes let app = Router::new() .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) - .route("/ws", get(ws_handler)) + .route("/ws", any(ws_handler)) // logging so we can see whats going on .layer( TraceLayer::new_for_http() @@ -77,7 +77,7 @@ async fn main() { .unwrap(); } -/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start +/// The handler for the HTTP request (this gets called when the HTTP request lands at the start /// of websocket negotiation). After this completes, the actual switching from HTTP to /// websocket protocol will occur. /// This is the last point where we can extract TCP/IP metadata such as IP address of the client