From a2ab338e6806f3df80bbb7cf0a1325c69bbbf546 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sun, 9 Oct 2022 21:33:40 +0200 Subject: [PATCH] Rewrite how state is passed from Router to MethodRouter --- .../into_service_state_in_extension.rs | 85 ---- axum/src/handler/mod.rs | 5 +- axum/src/routing/method_routing.rs | 389 +++++++++++------- axum/src/routing/mod.rs | 19 +- axum/src/routing/service.rs | 8 +- 5 files changed, 264 insertions(+), 242 deletions(-) delete mode 100644 axum/src/handler/into_service_state_in_extension.rs diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs deleted file mode 100644 index 304622a5..00000000 --- a/axum/src/handler/into_service_state_in_extension.rs +++ /dev/null @@ -1,85 +0,0 @@ -use super::Handler; -use crate::response::Response; -use http::Request; -use std::{ - convert::Infallible, - fmt, - marker::PhantomData, - sync::Arc, - task::{Context, Poll}, -}; -use tower_service::Service; - -pub(crate) struct IntoServiceStateInExtension { - handler: H, - _marker: PhantomData (T, S, B)>, -} - -#[test] -fn traits() { - use crate::test_helpers::*; - assert_send::>(); - assert_sync::>(); -} - -impl IntoServiceStateInExtension { - pub(crate) fn new(handler: H) -> Self { - Self { - handler, - _marker: PhantomData, - } - } -} - -impl fmt::Debug for IntoServiceStateInExtension { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IntoServiceStateInExtension") - .finish_non_exhaustive() - } -} - -impl Clone for IntoServiceStateInExtension -where - H: Clone, -{ - fn clone(&self) -> Self { - Self { - handler: self.handler.clone(), - _marker: PhantomData, - } - } -} - -impl Service> for IntoServiceStateInExtension -where - H: Handler + Clone + Send + 'static, - B: Send + 'static, - S: Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = super::future::IntoServiceFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - // `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or - // from `Layered` which bufferes in `::call` and is therefore - // also always ready. - Poll::Ready(Ok(())) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - use futures_util::future::FutureExt; - - let state = req - .extensions_mut() - .remove::>() - .expect("state extension missing. This is a bug in axum, please file an issue"); - - let handler = self.handler.clone(); - let future = Handler::call(handler, req, state); - let future = future.map(Ok as _); - - super::future::IntoServiceFuture::new(future) - } -} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index f63c58ae..68a6deb0 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -51,13 +51,10 @@ use tower_service::Service; mod boxed; pub mod future; -mod into_service_state_in_extension; mod service; +pub(crate) use self::boxed::BoxedHandler; pub use self::service::HandlerService; -pub(crate) use self::{ - boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension, -}; /// Trait for async functions that can be used to handle requests. /// diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index fffb6339..44067217 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -6,10 +6,11 @@ use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, Bytes, HttpBody}, error_handling::{HandleError, HandleErrorLayer}, - handler::{Handler, IntoServiceStateInExtension}, + handler::{BoxedHandler, Handler}, http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, + util::try_downcast, }; use axum_core::response::IntoResponse; use bytes::BytesMut; @@ -511,19 +512,19 @@ where /// {} /// ``` pub struct MethodRouter { - get: Option>, - head: Option>, - delete: Option>, - options: Option>, - patch: Option>, - post: Option>, - put: Option>, - trace: Option>, + get: MethodEndpoint, + head: MethodEndpoint, + delete: MethodEndpoint, + options: MethodEndpoint, + patch: MethodEndpoint, + post: MethodEndpoint, + put: MethodEndpoint, + trace: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } -#[derive(Clone)] +#[derive(Clone, Debug)] enum AllowHeader { /// No `Allow` header value has been built-up yet. This is the default state None, @@ -561,6 +562,7 @@ impl fmt::Debug for MethodRouter { .field("put", &self.put) .field("trace", &self.trace) .field("fallback", &self.fallback) + .field("allow_header", &self.allow_header) .finish() } } @@ -599,7 +601,10 @@ where T: 'static, S: Send + Sync + 'static, { - self.on_service(filter, IntoServiceStateInExtension::new(handler)) + self.on_endpoint( + filter, + MethodEndpoint::BoxedHandler(BoxedHandler::new(handler)), + ) } chained_handler_fn!(delete, DELETE); @@ -612,13 +617,14 @@ where chained_handler_fn!(trace, TRACE); /// Add a fallback [`Handler`] to the router. - pub fn fallback(self, handler: H) -> Self + pub fn fallback(mut self, handler: H) -> Self where H: Handler, T: 'static, S: Send + Sync + 'static, { - self.fallback_service(IntoServiceStateInExtension::new(handler)) + self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler)); + self } } @@ -708,14 +714,14 @@ where })); Self { - get: None, - head: None, - delete: None, - options: None, - patch: None, - post: None, - put: None, - trace: None, + get: MethodEndpoint::None, + head: MethodEndpoint::None, + delete: MethodEndpoint::None, + options: MethodEndpoint::None, + patch: MethodEndpoint::None, + post: MethodEndpoint::None, + put: MethodEndpoint::None, + trace: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } @@ -724,17 +730,25 @@ where /// Provide the state. /// /// See [`State`](crate::extract::State) for more details about accessing state. - pub fn with_state(self, state: S) -> WithState { + pub fn with_state(self, state: S) -> WithState { self.with_state_arc(Arc::new(state)) } /// Provide the [`Arc`]'ed state. /// /// See [`State`](crate::extract::State) for more details about accessing state. - pub fn with_state_arc(self, state: Arc) -> WithState { + pub fn with_state_arc(self, state: Arc) -> WithState { WithState { - method_router: self, - state, + get: self.get.into_route(&state), + head: self.head.into_route(&state), + delete: self.delete.into_route(&state), + options: self.options.into_route(&state), + patch: self.patch.into_route(&state), + post: self.post.into_route(&state), + put: self.put.into_route(&state), + trace: self.trace.into_route(&state), + fallback: self.fallback.into_route(&state), + allow_header: self.allow_header, } } @@ -745,14 +759,14 @@ where S2: 'static, { MethodRouter { - get: self.get, - head: self.head, - delete: self.delete, - options: self.options, - patch: self.patch, - post: self.post, - put: self.put, - trace: self.trace, + get: self.get.map_state(state), + head: self.head.map_state(state), + delete: self.delete.map_state(state), + options: self.options.map_state(state), + patch: self.patch.map_state(state), + post: self.post.map_state(state), + put: self.put.map_state(state), + trace: self.trace.map_state(state), fallback: self.fallback.map_state(state), allow_header: self.allow_header, } @@ -765,14 +779,14 @@ where S2: 'static, { Some(MethodRouter { - get: self.get, - head: self.head, - delete: self.delete, - options: self.options, - patch: self.patch, - post: self.post, - put: self.put, - trace: self.trace, + get: self.get.downcast_state()?, + head: self.head.downcast_state()?, + delete: self.delete.downcast_state()?, + options: self.options.downcast_state()?, + patch: self.patch.downcast_state()?, + post: self.post.downcast_state()?, + put: self.put.downcast_state()?, + trace: self.trace.downcast_state()?, fallback: self.fallback.downcast_state()?, allow_header: self.allow_header, }) @@ -804,112 +818,118 @@ where /// # }; /// ``` #[track_caller] - pub fn on_service(mut self, filter: MethodFilter, svc: T) -> Self + pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where T: Service, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - // written using an inner function to generate less IR + self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc))) + } + + #[track_caller] + fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint) -> Self { + // written as a separate function to generate less IR #[track_caller] - fn set_service( + fn set_endpoint( method_name: &str, - out: &mut Option, - svc: &T, - svc_filter: MethodFilter, + out: &mut MethodEndpoint, + endpoint: &MethodEndpoint, + endpoint_filter: MethodFilter, filter: MethodFilter, allow_header: &mut AllowHeader, methods: &[&'static str], ) where - T: Clone, + MethodEndpoint: Clone, { - if svc_filter.contains(filter) { + if endpoint_filter.contains(filter) { if out.is_some() { - panic!("Overlapping method route. Cannot add two method routes that both handle `{}`", method_name) + panic!( + "Overlapping method route. Cannot add two method routes that both handle \ + `{method_name}`", + ) } - *out = Some(svc.clone()); + *out = endpoint.clone(); for method in methods { append_allow_header(allow_header, method); } } } - let svc = Route::new(svc); - - set_service( + set_endpoint( "GET", &mut self.get, - &svc, + &endpoint, filter, MethodFilter::GET, &mut self.allow_header, &["GET", "HEAD"], ); - set_service( + set_endpoint( "HEAD", &mut self.head, - &svc, + &endpoint, filter, MethodFilter::HEAD, &mut self.allow_header, &["HEAD"], ); - set_service( + set_endpoint( "TRACE", &mut self.trace, - &svc, + &endpoint, filter, MethodFilter::TRACE, &mut self.allow_header, &["TRACE"], ); - set_service( + set_endpoint( "PUT", &mut self.put, - &svc, + &endpoint, filter, MethodFilter::PUT, &mut self.allow_header, &["PUT"], ); - set_service( + set_endpoint( "POST", &mut self.post, - &svc, + &endpoint, filter, MethodFilter::POST, &mut self.allow_header, &["POST"], ); - set_service( + set_endpoint( "PATCH", &mut self.patch, - &svc, + &endpoint, filter, MethodFilter::PATCH, &mut self.allow_header, &["PATCH"], ); - set_service( + set_endpoint( "OPTIONS", &mut self.options, - &svc, + &endpoint, filter, MethodFilter::OPTIONS, &mut self.allow_header, &["OPTIONS"], ); - set_service( + set_endpoint( "DELETE", &mut self.delete, - &svc, + &endpoint, filter, MethodFilter::DELETE, &mut self.allow_header, @@ -976,10 +996,12 @@ where #[track_caller] pub fn route_layer(mut self, layer: L) -> MethodRouter where - L: Layer>, + L: Layer> + Clone + Send + 'static, L::Service: Service, Error = E> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Future: Send + 'static, + E: 'static, + S: 'static, { if self.get.is_none() && self.head.is_none() @@ -996,19 +1018,19 @@ where ); } - let layer_fn = |svc| { + let layer_fn = move |svc| { let svc = layer.layer(svc); let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); Route::new(svc) }; - self.get = self.get.map(layer_fn); - self.head = self.head.map(layer_fn); - self.delete = self.delete.map(layer_fn); - self.options = self.options.map(layer_fn); - self.patch = self.patch.map(layer_fn); - self.post = self.post.map(layer_fn); - self.put = self.put.map(layer_fn); + self.get = self.get.map(layer_fn.clone()); + self.head = self.head.map(layer_fn.clone()); + self.delete = self.delete.map(layer_fn.clone()); + self.options = self.options.map(layer_fn.clone()); + self.patch = self.patch.map(layer_fn.clone()); + self.post = self.post.map(layer_fn.clone()); + self.put = self.put.map(layer_fn.clone()); self.trace = self.trace.map(layer_fn); self @@ -1022,24 +1044,27 @@ where ) -> Self { // written using inner functions to generate less IR #[track_caller] - fn merge_inner( + fn merge_inner( path: Option<&str>, name: &str, - first: Option, - second: Option, - ) -> Option { + first: MethodEndpoint, + second: MethodEndpoint, + ) -> MethodEndpoint { match (first, second) { - (Some(_), Some(_)) => { + (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None, + (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick, + _ => { if let Some(path) = path { panic!( "Overlapping method route. Handler for `{name} {path}` already exists" - ) + ); } else { - panic!("Overlapping method route. Cannot merge two method routes that both define `{name}`") + panic!( + "Overlapping method route. Cannot merge two method routes that both \ + define `{name}`" + ); } } - (Some(svc), None) | (None, Some(svc)) => Some(svc), - (None, None) => None, } } @@ -1155,6 +1180,92 @@ where } } +enum MethodEndpoint { + None, + Route(Route), + BoxedHandler(BoxedHandler), +} + +impl MethodEndpoint { + fn is_some(&self) -> bool { + matches!(self, Self::Route(_) | Self::BoxedHandler(_)) + } + + fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + fn map(self, f: F) -> MethodEndpoint + where + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route) -> Route + Clone + Send + 'static, + B2: 'static, + E2: 'static, + { + match self { + Self::None => MethodEndpoint::None, + Self::Route(route) => MethodEndpoint::Route(f(route)), + Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)), + } + } + + fn map_state(self, state: &Arc) -> MethodEndpoint { + match self { + Self::None => MethodEndpoint::None, + Self::Route(route) => MethodEndpoint::Route(route), + Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())), + } + } + + fn downcast_state(self) -> Option> + where + S: 'static, + B: 'static, + E: 'static, + S2: 'static, + { + match self { + Self::None => Some(MethodEndpoint::None), + Self::Route(route) => Some(MethodEndpoint::Route(route)), + Self::BoxedHandler(handler) => { + try_downcast::, BoxedHandler>(handler) + .map(MethodEndpoint::BoxedHandler) + .ok() + } + } + } + + fn into_route(self, state: &Arc) -> Option> { + match self { + Self::None => None, + Self::Route(route) => Some(route), + Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())), + } + } +} + +impl Clone for MethodEndpoint { + fn clone(&self) -> Self { + match self { + Self::None => Self::None, + Self::Route(inner) => Self::Route(inner.clone()), + Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()), + } + } +} + +impl fmt::Debug for MethodEndpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None => f.debug_tuple("None").finish(), + Self::Route(inner) => inner.fmt(f), + Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(), + } + } +} + /// A [`MethodRouter`] which has access to some state. /// /// Implements [`Service`]. @@ -1162,17 +1273,20 @@ where /// The state can be extracted with [`State`](crate::extract::State). /// /// Created with [`MethodRouter::with_state`] -pub struct WithState { - method_router: MethodRouter, - state: Arc, +pub struct WithState { + get: Option>, + head: Option>, + delete: Option>, + options: Option>, + patch: Option>, + post: Option>, + put: Option>, + trace: Option>, + fallback: Route, + allow_header: AllowHeader, } -impl WithState { - /// Get a reference to the state. - pub fn state(&self) -> &S { - &self.state - } - +impl WithState { /// Convert the handler into a [`MakeService`]. /// /// See [`MethodRouter::into_make_service`] for more details. @@ -1194,31 +1308,43 @@ impl WithState { } } -impl Clone for WithState { +impl Clone for WithState { fn clone(&self) -> Self { Self { - method_router: self.method_router.clone(), - state: Arc::clone(&self.state), + get: self.get.clone(), + head: self.head.clone(), + delete: self.delete.clone(), + options: self.options.clone(), + patch: self.patch.clone(), + post: self.post.clone(), + put: self.put.clone(), + trace: self.trace.clone(), + fallback: self.fallback.clone(), + allow_header: self.allow_header.clone(), } } } -impl fmt::Debug for WithState -where - S: fmt::Debug, -{ +impl fmt::Debug for WithState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WithState") - .field("method_router", &self.method_router) - .field("state", &self.state) + .field("get", &self.get) + .field("head", &self.head) + .field("delete", &self.delete) + .field("options", &self.options) + .field("patch", &self.patch) + .field("post", &self.post) + .field("put", &self.put) + .field("trace", &self.trace) + .field("fallback", &self.fallback) + .field("allow_header", &self.allow_header) .finish() } } -impl Service> for WithState +impl Service> for WithState where B: HttpBody + Send, - S: Send + Sync + 'static, { type Response = Response; type Error = E; @@ -1229,7 +1355,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { macro_rules! call { ( $req:expr, @@ -1250,24 +1376,18 @@ where // written with a pattern match like this to ensure we call all routes let Self { - state, - method_router: - MethodRouter { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - }, + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + allow_header, } = self; - req.extensions_mut().insert(Arc::clone(state)); - call!(req, method, HEAD, head); call!(req, method, HEAD, get); call!(req, method, GET, get); @@ -1278,18 +1398,7 @@ where call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - let future = match fallback { - Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) - .strip_body(method == Method::HEAD), - Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) - .strip_body(method == Method::HEAD), - Fallback::BoxedHandler(fallback) => RouteFuture::from_future( - fallback - .clone() - .into_route(Arc::clone(state)) - .oneshot_inner(req), - ), - }; + let future = RouteFuture::from_future(fallback.oneshot_inner(req)); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), @@ -1302,7 +1411,10 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, error_handling::HandleErrorLayer, extract::State}; + use crate::{ + body::Body, error_handling::HandleErrorLayer, extract::State, + handler::HandlerWithoutStateExt, + }; use axum_core::response::IntoResponse; use http::{header::ALLOW, HeaderMap}; use std::time::Duration; @@ -1517,8 +1629,7 @@ mod tests { expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" )] async fn service_overlaps() { - let _: MethodRouter<()> = post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)) - .post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)); + let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service()); } #[tokio::test] diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index f71dc0f4..35068adf 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -7,7 +7,6 @@ use crate::{ body::{Body, HttpBody}, handler::{BoxedHandler, Handler}, util::try_downcast, - Extension, }; use axum_core::response::IntoResponse; use http::Request; @@ -350,9 +349,7 @@ where // other has its state set Some(state) => { let fallback = fallback.map_state(&state); - cast_method_router_closure_slot = move |r: MethodRouter<_, _>| { - r.layer(Extension(Arc::clone(&state))).map_state(&state) - }; + cast_method_router_closure_slot = move |r: MethodRouter<_, _>| r.map_state(&state); let cast_method_router = &cast_method_router_closure_slot as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>; @@ -637,6 +634,14 @@ impl Fallback { _ => None, } } + + fn into_route(self, state: &Arc) -> Route { + match self { + Self::Default(route) => route, + Self::Service(route) => route, + Self::BoxedHandler(handler) => handler.into_route(state.clone()), + } + } } impl Clone for Fallback { @@ -677,6 +682,7 @@ impl Fallback { } } +#[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine enum Endpoint { MethodRouter(MethodRouter), Route(Route), @@ -691,10 +697,7 @@ impl Clone for Endpoint { } } -impl fmt::Debug for Endpoint -where - S: fmt::Debug, -{ +impl fmt::Debug for Endpoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MethodRouter(inner) => inner.fmt(f), diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index 982dc509..77d2a57d 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -10,7 +10,7 @@ use matchit::MatchError; use tower::Service; use super::{ - future::RouteFuture, url_params, Endpoint, Fallback, Node, Route, RouteId, Router, + future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router, NEST_TAIL_PARAM_CAPTURE, }; use crate::{ @@ -57,11 +57,7 @@ where Self { routes, node: router.node, - fallback: match router.fallback { - Fallback::Default(route) => route, - Fallback::Service(route) => route, - Fallback::BoxedHandler(handler) => handler.into_route(state), - }, + fallback: router.fallback.into_route(&state), } }