From c2e7fc0b3d337cba663bebaab900ddee14c695e9 Mon Sep 17 00:00:00 2001 From: Sabrina Jewson Date: Sun, 6 Oct 2024 20:09:06 +0100 Subject: [PATCH] Add `MethodFilter::CONNECT` (#2961) --- axum-extra/CHANGELOG.md | 2 + axum-extra/src/routing/mod.rs | 23 +++++++++ axum/CHANGELOG.md | 7 +++ axum/src/routing/method_filter.rs | 34 +++++++++++-- axum/src/routing/method_routing.rs | 79 +++++++++++++++++++++++++++++- axum/src/routing/mod.rs | 6 +-- 6 files changed, 143 insertions(+), 8 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 066029ae..25b84c0d 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **added:** Add `RouterExt::typed_connect` ([#2961]) - **added:** Add `json!` for easy construction of JSON responses ([#2962]) +[#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2962]: https://github.com/tokio-rs/axum/pull/2962 # 0.9.4 diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index a294c547..5fce9591 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -131,6 +131,19 @@ pub trait RouterExt: sealed::Sealed { T: SecondElementIs

+ 'static, P: TypedPath; + /// Add a typed `CONNECT` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath; + /// Add another route to the router with an additional "trailing slash redirect" route. /// /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a @@ -255,6 +268,16 @@ where self.route(P::PATH, axum::routing::trace(handler)) } + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::connect(handler)) + } + #[track_caller] fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index aa9d8067..ed355e82 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# Unreleased + +- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` + and `MethodRouter::connect[_service]` ([#2961]) + +[#2961]: https://github.com/tokio-rs/axum/pull/2961 + # 0.7.7 - **change**: Remove manual tables of content from the documentation, since diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 1cea4235..040783ec 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -9,6 +9,24 @@ use std::{ pub struct MethodFilter(u16); impl MethodFilter { + /// Match `CONNECT` requests. + /// + /// This is useful for implementing HTTP/2's [extended CONNECT method], + /// in which the `:protocol` pseudoheader is read + /// (using [`hyper::ext::Protocol`]) + /// and the connection upgraded to a bidirectional byte stream + /// (using [`hyper::upgrade::on`]). + /// + /// As seen in the [HTTP Upgrade Token Registry], + /// common uses include WebSockets and proxying UDP or IP – + /// though note that when using [`WebSocketUpgrade`] + /// it's more useful to use [`any`](crate::routing::any) + /// as HTTP/1.1 WebSockets need to support `GET`. + /// + /// [extended CONNECT]: https://www.rfc-editor.org/rfc/rfc8441.html#section-4 + /// [HTTP Upgrade Token Registry]: https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml + /// [`WebSocketUpgrade`]: crate::extract::WebSocketUpgrade + pub const CONNECT: Self = Self::from_bits(0b0_0000_0001); /// Match `DELETE` requests. pub const DELETE: Self = Self::from_bits(0b0_0000_0010); /// Match `GET` requests. @@ -71,6 +89,7 @@ impl TryFrom for MethodFilter { fn try_from(m: Method) -> Result { match m { + Method::CONNECT => Ok(MethodFilter::CONNECT), Method::DELETE => Ok(MethodFilter::DELETE), Method::GET => Ok(MethodFilter::GET), Method::HEAD => Ok(MethodFilter::HEAD), @@ -90,6 +109,11 @@ mod tests { #[test] fn from_http_method() { + assert_eq!( + MethodFilter::try_from(Method::CONNECT).unwrap(), + MethodFilter::CONNECT + ); + assert_eq!( MethodFilter::try_from(Method::DELETE).unwrap(), MethodFilter::DELETE @@ -130,9 +154,11 @@ mod tests { MethodFilter::TRACE ); - assert!(MethodFilter::try_from(http::Method::CONNECT) - .unwrap_err() - .to_string() - .contains("CONNECT")); + assert!( + MethodFilter::try_from(http::Method::from_bytes(b"CUSTOM").unwrap()) + .unwrap_err() + .to_string() + .contains("CUSTOM") + ); } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 1eb6075b..5ed4f6a9 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -59,6 +59,19 @@ macro_rules! top_level_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_service_fn!( + /// Route `CONNECT` requests to the given service. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -118,6 +131,19 @@ macro_rules! top_level_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_handler_fn!( + /// Route `CONNECT` requests to the given handler. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -187,6 +213,19 @@ macro_rules! chained_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_service_fn!( + /// Chain an additional service that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -249,6 +288,19 @@ macro_rules! chained_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_handler_fn!( + /// Chain an additional handler that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -278,6 +330,7 @@ macro_rules! chained_handler_fn { }; } +top_level_service_fn!(connect_service, CONNECT); top_level_service_fn!(delete_service, DELETE); top_level_service_fn!(get_service, GET); top_level_service_fn!(head_service, HEAD); @@ -381,6 +434,7 @@ where .skip_allow_header() } +top_level_handler_fn!(connect, CONNECT); top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(get, GET); top_level_handler_fn!(head, HEAD); @@ -497,6 +551,7 @@ pub struct MethodRouter { post: MethodEndpoint, put: MethodEndpoint, trace: MethodEndpoint, + connect: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } @@ -538,6 +593,7 @@ impl fmt::Debug for MethodRouter { .field("post", &self.post) .field("put", &self.put) .field("trace", &self.trace) + .field("connect", &self.connect) .field("fallback", &self.fallback) .field("allow_header", &self.allow_header) .finish() @@ -582,6 +638,7 @@ where ) } + chained_handler_fn!(connect, CONNECT); chained_handler_fn!(delete, DELETE); chained_handler_fn!(get, GET); chained_handler_fn!(head, HEAD); @@ -689,6 +746,7 @@ where post: MethodEndpoint::None, put: MethodEndpoint::None, trace: MethodEndpoint::None, + connect: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } @@ -705,6 +763,7 @@ where post: self.post.with_state(&state), put: self.put.with_state(&state), trace: self.trace.with_state(&state), + connect: self.connect.with_state(&state), allow_header: self.allow_header, fallback: self.fallback.with_state(state), } @@ -853,9 +912,20 @@ where &["DELETE"], ); + set_endpoint( + "CONNECT", + &mut self.options, + &endpoint, + filter, + MethodFilter::CONNECT, + &mut self.allow_header, + &["CONNECT"], + ); + self } + chained_service_fn!(connect_service, CONNECT); chained_service_fn!(delete_service, DELETE); chained_service_fn!(get_service, GET); chained_service_fn!(head_service, HEAD); @@ -899,6 +969,7 @@ where post: self.post.map(layer_fn.clone()), put: self.put.map(layer_fn.clone()), trace: self.trace.map(layer_fn.clone()), + connect: self.connect.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, } @@ -923,6 +994,7 @@ where && self.post.is_none() && self.put.is_none() && self.trace.is_none() + && self.connect.is_none() { panic!( "Adding a route_layer before any routes is a no-op. \ @@ -943,7 +1015,8 @@ where self.patch = self.patch.map(layer_fn.clone()); self.post = self.post.map(layer_fn.clone()); self.put = self.put.map(layer_fn.clone()); - self.trace = self.trace.map(layer_fn); + self.trace = self.trace.map(layer_fn.clone()); + self.connect = self.connect.map(layer_fn); self } @@ -984,6 +1057,7 @@ where self.post = merge_inner(path, "POST", self.post, other.post); self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); + self.connect = merge_inner(path, "CONNECT", self.connect, other.connect); self.fallback = self .fallback @@ -1059,6 +1133,7 @@ where post, put, trace, + connect, fallback, allow_header, } = self; @@ -1072,6 +1147,7 @@ where call!(req, method, PUT, put); call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); + call!(req, method, CONNECT, connect); let future = fallback.clone().call_with_state(req, state); @@ -1114,6 +1190,7 @@ impl Clone for MethodRouter { post: self.post.clone(), put: self.put.clone(), trace: self.trace.clone(), + connect: self.connect.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index dc6ca815..822be773 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -40,9 +40,9 @@ mod tests; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ - any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, - options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, + any, any_service, connect, connect_service, delete, delete_service, get, get_service, head, + head_service, on, on_service, options, options_service, patch, patch_service, post, + post_service, put, put_service, trace, trace_service, MethodRouter, }; macro_rules! panic_on_err {