diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 844c1a5d..2c402fb6 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -149,7 +149,7 @@ impl<B, T, F> SpaRouter<B, T, F> { impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B> where - F: Clone + Send + 'static, + F: Clone + Send + Sync + 'static, HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>, <HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send, <HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send, @@ -161,7 +161,7 @@ where .handle_error(spa.handle_error.clone()); Router::new() - .nest(&spa.paths.assets_path, assets_service) + .nest_service(&spa.paths.assets_path, assets_service) .fallback_service( get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error), ) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 51a6622e..9de9808c 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -22,6 +22,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `map_request_with_state_arc` for transforming the request with an async function ([#1408]) - **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400]) +- **changed:** `Router` no longer implements `Service`, call `.into_service()` + on it to obtain a `RouterService` that does +- **added:** Add `Router::inherit_state`, which creates a `Router` with an + arbitrary state type without actually supplying the state; such a `Router` + can't be turned into a service directly (`.into_service()` will panic), but + can be nested or merged into a `Router` with the same state type +- **changed:** `Router::nest` now only accepts `Router`s, the general-purpose + `Service` nesting method has been renamed to `nest_service` [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1387]: https://github.com/tokio-rs/axum/pull/1387 diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index f60a83d8..3b03d4d9 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -16,10 +16,10 @@ let user_routes = Router::new().route("/:id", get(|| async {})); let team_routes = Router::new().route("/", post(|| async {})); let api_routes = Router::new() - .nest("/users", user_routes.into_service()) - .nest("/teams", team_routes.into_service()); + .nest("/users", user_routes) + .nest("/teams", team_routes); -let app = Router::new().nest("/api", api_routes.into_service()); +let app = Router::new().nest("/api", api_routes); // Our app now accepts // - GET /api/users/:id @@ -58,7 +58,7 @@ async fn users_get(Path(params): Path<HashMap<String, String>>) { let users_api = Router::new().route("/users/:id", get(users_get)); -let app = Router::new().nest("/:version/api", users_api.into_service()); +let app = Router::new().nest("/:version/api", users_api); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; @@ -82,7 +82,7 @@ let app = Router::new() .route("/foo/*rest", get(|uri: Uri| async { // `uri` will contain `/foo` })) - .nest("/bar", nested_router.into_service()); + .nest("/bar", nested_router); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; @@ -100,10 +100,10 @@ async fn fallback() -> (StatusCode, &'static str) { (StatusCode::NOT_FOUND, "Not Found") } -let api_routes = Router::new().nest("/users", get(|| async {})); +let api_routes = Router::new().nest_service("/users", get(|| async {})); let app = Router::new() - .nest("/api", api_routes.into_service()) + .nest("/api", api_routes) .fallback(fallback); # let _: Router = app; ``` @@ -130,12 +130,12 @@ async fn api_fallback() -> (StatusCode, Json<Value>) { } let api_routes = Router::new() - .nest("/users", get(|| async {})) + .nest_service("/users", get(|| async {})) // add dedicated fallback for requests starting with `/api` .fallback(api_fallback); let app = Router::new() - .nest("/api", api_routes.into_service()) + .nest("/api", api_routes) .fallback(fallback); # let _: Router = app; ``` diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 4472012a..1e35544e 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -148,7 +148,7 @@ mod tests { "/:key", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ) - .nest("/api", api.into_service()) + .nest("/api", api) .nest( "/public", Router::new() @@ -156,10 +156,9 @@ mod tests { // have to set the middleware here since otherwise the // matched path is just `/public/*` since we're nesting // this router - .layer(layer_fn(SetMatchedPathExtension)) - .into_service(), + .layer(layer_fn(SetMatchedPathExtension)), ) - .nest("/foo", handler.into_service()) + .nest_service("/foo", handler.into_service()) .layer(layer_fn(SetMatchedPathExtension)); let client = TestClient::new(app); @@ -198,12 +197,10 @@ mod tests { async fn nested_opaque_routers_append_to_matched_path() { let app = Router::new().nest( "/:a", - Router::new() - .route( - "/:b", - get(|path: MatchedPath| async move { path.as_str().to_owned() }), - ) - .into_service(), + Router::new().route( + "/:b", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ), ); let client = TestClient::new(app); diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index bf35fe75..933e7a71 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -38,7 +38,7 @@ use sync_wrapper::SyncWrapper; /// }), /// ); /// -/// let app = Router::new().nest("/api", api_routes.into_service()); +/// let app = Router::new().nest("/api", api_routes); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; @@ -75,7 +75,7 @@ use sync_wrapper::SyncWrapper; /// }), /// ); /// -/// let app = Router::new().nest("/api", api_routes.into_service()); +/// let app = Router::new().nest("/api", api_routes); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; diff --git a/axum/src/handler/boxed.rs b/axum/src/handler/boxed.rs new file mode 100644 index 00000000..34b5c31d --- /dev/null +++ b/axum/src/handler/boxed.rs @@ -0,0 +1,122 @@ +use std::{convert::Infallible, sync::Arc}; + +use super::Handler; +use crate::routing::Route; + +pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B, E>>); + +impl<S, B> BoxedHandler<S, B> +where + S: Send + Sync + 'static, + B: Send + 'static, +{ + pub(crate) fn new<H, T>(handler: H) -> Self + where + H: Handler<T, S, B>, + T: 'static, + { + Self(Box::new(MakeErasedHandler { + handler, + into_route: |handler, state| Route::new(Handler::with_state_arc(handler, state)), + })) + } +} + +impl<S, B, E> BoxedHandler<S, B, E> { + pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedHandler<S, B2, E2> + where + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, + B2: 'static, + E2: 'static, + { + BoxedHandler(Box::new(Map { + handler: self.0, + layer: Box::new(f), + })) + } + + pub(crate) fn into_route(self, state: Arc<S>) -> Route<B, E> { + self.0.into_route(state) + } +} + +impl<S, B, E> Clone for BoxedHandler<S, B, E> { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +trait ErasedHandler<S, B, E = Infallible>: Send { + fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>; + fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B, E>; +} + +struct MakeErasedHandler<H, S, B> { + handler: H, + into_route: fn(H, Arc<S>) -> Route<B>, +} + +impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B> +where + H: Clone + Send + 'static, + S: 'static, + B: 'static, +{ + fn clone_box(&self) -> Box<dyn ErasedHandler<S, B>> { + Box::new(self.clone()) + } + + fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B> { + (self.into_route)(self.handler, state) + } +} + +impl<H: Clone, S, B> Clone for MakeErasedHandler<H, S, B> { + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + into_route: self.into_route, + } + } +} + +struct Map<S, B, E, B2, E2> { + handler: Box<dyn ErasedHandler<S, B, E>>, + layer: Box<dyn LayerFn<B, E, B2, E2>>, +} + +impl<S, B, E, B2, E2> ErasedHandler<S, B2, E2> for Map<S, B, E, B2, E2> +where + S: 'static, + B: 'static, + E: 'static, + B2: 'static, + E2: 'static, +{ + fn clone_box(&self) -> Box<dyn ErasedHandler<S, B2, E2>> { + Box::new(Self { + handler: self.handler.clone_box(), + layer: self.layer.clone_box(), + }) + } + + fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B2, E2> { + (self.layer)(self.handler.into_route(state)) + } +} + +trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send { + fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>; +} + +impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F +where + F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, +{ + fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> { + Box::new(self.clone()) + } +} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 3fa51ca0..84f2a444 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -47,12 +47,15 @@ use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; +mod boxed; pub mod future; mod into_service; mod into_service_state_in_extension; mod with_state; -pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension; +pub(crate) use self::{ + boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension, +}; pub use self::{into_service::IntoService, with_state::WithState}; /// Trait for async functions that can be used to handle requests. diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 8bf8c25a..5c0db520 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -15,7 +15,6 @@ use bytes::BytesMut; use std::{ convert::Infallible, fmt, - marker::PhantomData, sync::Arc, task::{Context, Poll}, }; @@ -521,9 +520,8 @@ pub struct MethodRouter<S = (), B = Body, E = Infallible> { post: Option<Route<B, E>>, put: Option<Route<B, E>>, trace: Option<Route<B, E>>, - fallback: Fallback<B, E>, + fallback: Fallback<S, B, E>, allow_header: AllowHeader, - _marker: PhantomData<fn() -> S>, } #[derive(Clone)] @@ -720,7 +718,6 @@ where trace: None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), - _marker: PhantomData, } } @@ -741,7 +738,12 @@ where } } - pub(crate) fn downcast_state<S2>(self) -> MethodRouter<S2, B, E> { + pub(crate) fn map_state<S2>(self, state: &Arc<S>) -> MethodRouter<S2, B, E> + where + E: 'static, + S: 'static, + S2: 'static, + { MethodRouter { get: self.get, head: self.head, @@ -751,12 +753,31 @@ where post: self.post, put: self.put, trace: self.trace, - fallback: self.fallback, + fallback: self.fallback.map_state(state), allow_header: self.allow_header, - _marker: PhantomData, } } + pub(crate) fn downcast_state<S2>(self) -> Option<MethodRouter<S2, B, E>> + where + E: 'static, + S: 'static, + S2: 'static, + { + Some(MethodRouter { + get: self.get, + head: self.head, + delete: self.delete, + options: self.options, + patch: self.patch, + post: self.post, + put: self.put, + trace: self.trace, + fallback: self.fallback.downcast_state()?, + allow_header: self.allow_header, + }) + } + /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -808,7 +829,7 @@ where T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - self.fallback = Fallback::Custom(Route::new(svc)); + self.fallback = Fallback::Service(Route::new(svc)); self } @@ -818,36 +839,40 @@ where T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - self.fallback = Fallback::Custom(Route::new(svc)); + self.fallback = Fallback::Service(Route::new(svc)); self } #[doc = include_str!("../docs/method_routing/layer.md")] - pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError> + pub fn layer<L, NewReqBody: 'static, NewError: 'static>( + self, + layer: L, + ) -> MethodRouter<S, NewReqBody, NewError> where - L: Layer<Route<B, E>>, + L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, + E: 'static, + S: 'static, { - let layer_fn = |svc| { + let layer_fn = move |svc| { let svc = layer.layer(svc); let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); Route::new(svc) }; MethodRouter { - get: self.get.map(layer_fn), - head: self.head.map(layer_fn), - delete: self.delete.map(layer_fn), - options: self.options.map(layer_fn), - patch: self.patch.map(layer_fn), - post: self.post.map(layer_fn), - put: self.put.map(layer_fn), - trace: self.trace.map(layer_fn), + get: self.get.map(layer_fn.clone()), + head: self.head.map(layer_fn.clone()), + delete: self.delete.map(layer_fn.clone()), + options: self.options.map(layer_fn.clone()), + patch: self.patch.map(layer_fn.clone()), + post: self.post.map(layer_fn.clone()), + put: self.put.map(layer_fn.clone()), + trace: self.trace.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, - _marker: self._marker, } } @@ -952,13 +977,14 @@ where /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible> where - F: Clone + Send + 'static, + F: Clone + Send + Sync + 'static, HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send, T: 'static, E: 'static, B: 'static, + S: 'static, { self.layer(HandleErrorLayer::new(f)) } @@ -1136,7 +1162,6 @@ impl<S, B, E> Clone for MethodRouter<S, B, E> { trace: self.trace.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), - _marker: self._marker, } } } @@ -1211,7 +1236,7 @@ where impl<S, B, E> Service<Request<B>> for WithState<S, B, E> where - B: HttpBody, + B: HttpBody + Send, S: Send + Sync + 'static, { type Response = Response; @@ -1257,7 +1282,6 @@ where trace, fallback, allow_header, - _marker: _, }, } = self; @@ -1276,8 +1300,14 @@ where let future = match fallback { Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) .strip_body(method == Method::HEAD), - Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) + Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) .strip_body(method == Method::HEAD), + Fallback::BoxedHandler(fallback) => RouteFuture::from_future( + fallback + .clone() + .into_route(Arc::clone(state)) + .oneshot_inner(req), + ), }; match allow_header { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 687c51c6..cb20c9e3 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -4,14 +4,20 @@ use self::not_found::NotFound; use crate::{ body::{Body, HttpBody}, extract::connect_info::IntoMakeServiceWithConnectInfo, - handler::Handler, + handler::{BoxedHandler, Handler}, util::try_downcast, Extension, }; use axum_core::response::IntoResponse; use http::Request; use matchit::MatchError; -use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc}; +use std::{ + any::{type_name, TypeId}, + collections::HashMap, + convert::Infallible, + fmt, + sync::Arc, +}; use tower::{util::MapResponseLayer, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; @@ -59,16 +65,16 @@ impl RouteId { /// The router type for composing handlers and services. pub struct Router<S = (), B = Body> { - state: Arc<S>, + state: Option<Arc<S>>, routes: HashMap<RouteId, Endpoint<S, B>>, node: Arc<Node>, - fallback: Fallback<B>, + fallback: Fallback<S, B>, } impl<S, B> Clone for Router<S, B> { fn clone(&self) -> Self { Self { - state: Arc::clone(&self.state), + state: self.state.clone(), routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), @@ -162,7 +168,18 @@ where /// [`State`]: crate::extract::State pub fn with_state_arc(state: Arc<S>) -> Self { Self { - state, + state: Some(state), + routes: Default::default(), + node: Default::default(), + fallback: Fallback::Default(Route::new(NotFound)), + } + } + + /// Create a new `Router` that inherits its state from another `Router` that it is merged into + /// or nested under. + pub fn inherit_state() -> Self { + Self { + state: None, routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), @@ -253,7 +270,29 @@ where #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] - pub fn nest<T>(mut self, mut path: &str, svc: T) -> Self + pub fn nest<S2>(self, path: &str, mut router: Router<S2, B>) -> Self + where + S2: Send + Sync + 'static, + { + if router.state.is_none() { + let s = self.state.clone(); + router.state = match try_downcast::<Option<Arc<S2>>, Option<Arc<S>>>(s) { + Ok(state) => state, + Err(_) => panic!( + "can't nest a `Router` that wants to inherit state of type `{}` \ + into a `Router` with a state type of `{}`", + type_name::<S2>(), + type_name::<S>(), + ), + }; + } + + self.nest_service(path, router.into_service()) + } + + /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. + #[track_caller] + pub fn nest_service<T>(mut self, mut path: &str, svc: T) -> Self where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, @@ -305,20 +344,55 @@ where fallback, } = other.into(); + let cast_method_router_closure_slot; + let (fallback, cast_method_router) = match state { + // other has its state set + Some(state) => { + let fallback = fallback.map_state(&state); + cast_method_router_closure_slot = move |r: MethodRouter<_, _>| { + r.layer(Extension(Arc::clone(&state))).map_state(&state) + }; + let cast_method_router = &cast_method_router_closure_slot + as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>; + + (fallback, cast_method_router) + } + // other wants to inherit its state + None => { + if TypeId::of::<S>() != TypeId::of::<S2>() { + panic!( + "can't merge a `Router` that wants to inherit state of type `{}` \ + into a `Router` with a state type of `{}`", + type_name::<S2>(), + type_name::<S>(), + ); + } + + // With the branch above not taken, we know we can cast S2 to S + let fallback = fallback.downcast_state::<S>().unwrap(); + + fn cast_method_router<S, S2, B>(r: MethodRouter<S2, B>) -> MethodRouter<S, B> + where + B: Send + 'static, + S: 'static, + S2: 'static, + { + r.downcast_state().unwrap() + } + + (fallback, &cast_method_router as _) + } + }; + for (id, route) in routes { let path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); self = match route { - Endpoint::MethodRouter(method_router) => self.route( - path, - method_router - // this will set the state for each route - // such we don't override the inner state later in `MethodRouterWithState` - .layer(Extension(Arc::clone(&state))) - .downcast_state(), - ), + Endpoint::MethodRouter(method_router) => { + self.route(path, cast_method_router(method_router)) + } Endpoint::Route(route) => self.route_service(path, route), }; } @@ -332,9 +406,9 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody> + pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody> where - L: Layer<Route<B>>, + L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, @@ -352,7 +426,7 @@ where .map(|(id, route)| { let route = match route { Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.layer(&layer)) + Endpoint::MethodRouter(method_router.layer(layer.clone())) } Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), }; @@ -360,7 +434,7 @@ where }) .collect(); - let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc))); + let fallback = self.fallback.map(move |svc| Route::new(layer.layer(svc))); Router { state: self.state, @@ -374,7 +448,7 @@ where #[track_caller] pub fn route_layer<L>(self, layer: L) -> Self where - L: Layer<Route<B>>, + L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<B>> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, @@ -399,7 +473,7 @@ where .map(|(id, route)| { let route = match route { Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.layer(&layer)) + Endpoint::MethodRouter(method_router.layer(layer.clone())) } Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), }; @@ -416,13 +490,13 @@ where } #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback<H, T>(self, handler: H) -> Self + pub fn fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, { - let state = Arc::clone(&self.state); - self.fallback_service(handler.with_state_arc(state)) + self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler)); + self } /// Add a fallback [`Service`] to the router. @@ -434,7 +508,7 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - self.fallback = Fallback::Custom(Route::new(svc)); + self.fallback = Fallback::Service(Route::new(svc)); self } @@ -478,11 +552,6 @@ where ) -> IntoMakeServiceWithConnectInfo<RouterService<B>, C> { IntoMakeServiceWithConnectInfo::new(self.into_service()) } - - /// Get a reference to the state. - pub fn state(&self) -> &S { - &self.state - } } /// Wrapper around `matchit::Router` that supports merging two `Router`s. @@ -526,12 +595,39 @@ impl fmt::Debug for Node { } } -enum Fallback<B, E = Infallible> { +enum Fallback<S, B, E = Infallible> { Default(Route<B, E>), - Custom(Route<B, E>), + Service(Route<B, E>), + BoxedHandler(BoxedHandler<S, B, E>), } -impl<B, E> Fallback<B, E> { +impl<S, B, E> Fallback<S, B, E> { + fn map_state<S2>(self, state: &Arc<S>) -> Fallback<S2, B, E> { + match self { + Self::Default(route) => Fallback::Default(route), + Self::Service(route) => Fallback::Service(route), + Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state.clone())), + } + } + + fn downcast_state<S2>(self) -> Option<Fallback<S2, B, E>> + where + S: 'static, + B: 'static, + E: 'static, + S2: 'static, + { + match self { + Self::Default(route) => Some(Fallback::Default(route)), + Self::Service(route) => Some(Fallback::Service(route)), + Self::BoxedHandler(handler) => { + try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler) + .map(Fallback::BoxedHandler) + .ok() + } + } + } + fn merge(self, other: Self) -> Option<Self> { match (self, other) { (Self::Default(_), pick @ Self::Default(_)) => Some(pick), @@ -541,32 +637,40 @@ impl<B, E> Fallback<B, E> { } } -impl<B, E> Clone for Fallback<B, E> { +impl<S, B, E> Clone for Fallback<S, B, E> { fn clone(&self) -> Self { match self { Self::Default(inner) => Self::Default(inner.clone()), - Self::Custom(inner) => Self::Custom(inner.clone()), + Self::Service(inner) => Self::Service(inner.clone()), + Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()), } } } -impl<B, E> fmt::Debug for Fallback<B, E> { +impl<S, B, E> fmt::Debug for Fallback<S, B, E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), - Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(), + Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(), + Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(), } } } -impl<B, E> Fallback<B, E> { - fn map<F, B2, E2>(self, f: F) -> Fallback<B2, E2> +impl<S, B, E> Fallback<S, B, E> { + fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2> where - F: FnOnce(Route<B, E>) -> Route<B2, E2>, + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, + B2: 'static, + E2: 'static, { match self { Self::Default(inner) => Fallback::Default(f(inner)), - Self::Custom(inner) => Fallback::Custom(f(inner)), + Self::Service(inner) => Fallback::Service(f(inner)), + Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)), } } } diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 8b3959df..4b735a8d 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -29,7 +29,7 @@ use tower_service::Service; pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>); impl<B, E> Route<B, E> { - pub(super) fn new<T>(svc: T) -> Self + pub(crate) fn new<T>(svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index 6b0b4ac2..982dc509 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -30,17 +30,22 @@ impl<B> RouterService<B> where B: HttpBody + Send + 'static, { + #[track_caller] pub(super) fn new<S>(router: Router<S, B>) -> Self where S: Send + Sync + 'static, { + let state = router + .state + .expect("Can't turn a `Router` that wants to inherit state into a service"); + let routes = router .routes .into_iter() .map(|(route_id, endpoint)| { let route = match endpoint { Endpoint::MethodRouter(method_router) => { - Route::new(method_router.with_state_arc(Arc::clone(&router.state))) + Route::new(method_router.with_state_arc(Arc::clone(&state))) } Endpoint::Route(route) => route, }; @@ -54,7 +59,8 @@ where node: router.node, fallback: match router.fallback { Fallback::Default(route) => route, - Fallback::Custom(route) => route, + Fallback::Service(route) => route, + Fallback::BoxedHandler(handler) => handler.into_route(state), }, } } diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index d9b27da0..4da166ba 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -18,10 +18,7 @@ async fn basic() { #[tokio::test] async fn nest() { let app = Router::new() - .nest( - "/foo", - Router::new().route("/bar", get(|| async {})).into_service(), - ) + .nest("/foo", Router::new().route("/bar", get(|| async {}))) .fallback(|| async { "fallback" }); let client = TestClient::new(app); diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 7ff34a77..804549d3 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -82,7 +82,7 @@ async fn nested_or() { assert_eq!(client.get("/bar").send().await.text().await, "bar"); assert_eq!(client.get("/baz").send().await.text().await, "baz"); - let client = TestClient::new(Router::new().nest("/foo", bar_or_baz.into_service())); + let client = TestClient::new(Router::new().nest("/foo", bar_or_baz)); assert_eq!(client.get("/foo/bar").send().await.text().await, "bar"); assert_eq!(client.get("/foo/baz").send().await.text().await, "baz"); } @@ -145,10 +145,7 @@ async fn layer_and_handle_error() { #[tokio::test] async fn nesting() { let one = Router::new().route("/foo", get(|| async {})); - let two = Router::new().nest( - "/bar", - Router::new().route("/baz", get(|| async {})).into_service(), - ); + let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {}))); let app = one.merge(two); let client = TestClient::new(app); @@ -232,12 +229,7 @@ async fn all_the_uris( #[tokio::test] async fn nesting_and_seeing_the_right_uri() { - let one = Router::new().nest( - "/foo/", - Router::new() - .route("/bar", get(all_the_uris)) - .into_service(), - ); + let one = Router::new().nest("/foo/", Router::new().route("/bar", get(all_the_uris))); let two = Router::new().route("/foo", get(all_the_uris)); let client = TestClient::new(one.merge(two)); @@ -269,14 +261,7 @@ async fn nesting_and_seeing_the_right_uri() { async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { let one = Router::new().nest( "/foo/", - Router::new() - .nest( - "/bar", - Router::new() - .route("/baz", get(all_the_uris)) - .into_service(), - ) - .into_service(), + Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), ); let two = Router::new().route("/foo", get(all_the_uris)); @@ -309,21 +294,9 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { let one = Router::new().nest( "/one", - Router::new() - .nest( - "/bar", - Router::new() - .route("/baz", get(all_the_uris)) - .into_service(), - ) - .into_service(), - ); - let two = Router::new().nest( - "/two", - Router::new() - .route("/qux", get(all_the_uris)) - .into_service(), + Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), ); + let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris))); let three = Router::new().route("/three", get(all_the_uris)); let client = TestClient::new(one.merge(two).merge(three)); @@ -366,14 +339,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { let one = Router::new().nest( "/one", - Router::new() - .nest( - "/foo", - Router::new() - .route("/bar", get(all_the_uris)) - .into_service(), - ) - .into_service(), + Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))), ); let two = Router::new().route("/two/foo", get(all_the_uris)); @@ -500,3 +466,18 @@ async fn merging_routes_different_paths_different_states() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "bar state"); } + +#[tokio::test] +async fn inherit_state_via_merge() { + let foo = Router::inherit_state().route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::with_state("state").merge(foo); + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "state"); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index c61efbf4..8829a64a 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -19,9 +19,7 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::{ - service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt, -}; +use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder}; use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer}; use tower_service::Service; diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index a0243f1c..42c4ffe9 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -37,7 +37,7 @@ async fn nesting_apps() { let app = Router::new() .route("/", get(|| async { "hi" })) - .nest("/:version/api", api_routes.into_service()); + .nest("/:version/api", api_routes); let client = TestClient::new(app); @@ -61,7 +61,7 @@ async fn nesting_apps() { #[tokio::test] async fn wrong_method_nest() { let nested_app = Router::new().route("/", get(|| async {})); - let app = Router::new().nest("/", nested_app.into_service()); + let app = Router::new().nest("/", nested_app); let client = TestClient::new(app); @@ -78,7 +78,7 @@ async fn wrong_method_nest() { #[tokio::test] async fn nesting_router_at_root() { let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("/", nested.into_service()); + let app = Router::new().nest("/", nested); let client = TestClient::new(app); @@ -96,7 +96,7 @@ async fn nesting_router_at_root() { #[tokio::test] async fn nesting_router_at_empty_path() { let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("", nested.into_service()); + let app = Router::new().nest("", nested); let client = TestClient::new(app); @@ -113,7 +113,7 @@ async fn nesting_router_at_empty_path() { #[tokio::test] async fn nesting_handler_at_root() { - let app = Router::new().nest("/", get(|uri: Uri| async move { uri.to_string() })); + let app = Router::new().nest_service("/", get(|uri: Uri| async move { uri.to_string() })); let client = TestClient::new(app); @@ -134,18 +134,15 @@ async fn nesting_handler_at_root() { async fn nested_url_extractor() { let app = Router::new().nest( "/foo", - Router::new() - .nest( - "/bar", - Router::new() - .route("/baz", get(|uri: Uri| async move { uri.to_string() })) - .route( - "/qux", - get(|req: Request<Body>| async move { req.uri().to_string() }), - ) - .into_service(), - ) - .into_service(), + Router::new().nest( + "/bar", + Router::new() + .route("/baz", get(|uri: Uri| async move { uri.to_string() })) + .route( + "/qux", + get(|req: Request<Body>| async move { req.uri().to_string() }), + ), + ), ); let client = TestClient::new(app); @@ -163,17 +160,13 @@ async fn nested_url_extractor() { async fn nested_url_original_extractor() { let app = Router::new().nest( "/foo", - Router::new() - .nest( - "/bar", - Router::new() - .route( - "/baz", - get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), - ) - .into_service(), - ) - .into_service(), + Router::new().nest( + "/bar", + Router::new().route( + "/baz", + get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), + ), + ), ); let client = TestClient::new(app); @@ -187,20 +180,16 @@ async fn nested_url_original_extractor() { async fn nested_service_sees_stripped_uri() { let app = Router::new().nest( "/foo", - Router::new() - .nest( - "/bar", - Router::new() - .route_service( - "/baz", - service_fn(|req: Request<Body>| async move { - let body = boxed(Body::from(req.uri().to_string())); - Ok::<_, Infallible>(Response::new(body)) - }), - ) - .into_service(), - ) - .into_service(), + Router::new().nest( + "/bar", + Router::new().route_service( + "/baz", + service_fn(|req: Request<Body>| async move { + let body = boxed(Body::from(req.uri().to_string())); + Ok::<_, Infallible>(Response::new(body)) + }), + ), + ), ); let client = TestClient::new(app); @@ -212,7 +201,7 @@ async fn nested_service_sees_stripped_uri() { #[tokio::test] async fn nest_static_file_server() { - let app = Router::new().nest( + let app = Router::new().nest_service( "/static", get_service(ServeDir::new(".")).handle_error(|error| async move { ( @@ -235,8 +224,7 @@ async fn nested_multiple_routes() { "/api", Router::new() .route("/users", get(|| async { "users" })) - .route("/teams", get(|| async { "teams" })) - .into_service(), + .route("/teams", get(|| async { "teams" })), ) .route("/", get(|| async { "root" })); @@ -251,12 +239,7 @@ async fn nested_multiple_routes() { #[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] fn nested_at_root_with_other_routes() { let _: Router = Router::new() - .nest( - "/", - Router::new() - .route("/users", get(|| async {})) - .into_service(), - ) + .nest("/", Router::new().route("/users", get(|| async {}))) .route("/", get(|| async {})); } @@ -265,15 +248,11 @@ async fn multiple_top_level_nests() { let app = Router::new() .nest( "/one", - Router::new() - .route("/route", get(|| async { "one" })) - .into_service(), + Router::new().route("/route", get(|| async { "one" })), ) .nest( "/two", - Router::new() - .route("/route", get(|| async { "two" })) - .into_service(), + Router::new().route("/route", get(|| async { "two" })), ); let client = TestClient::new(app); @@ -285,7 +264,7 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::<_, Body>::new().nest("/one/*rest", Router::new().into_service()); + Router::<_, Body>::new().nest("/one/*rest", Router::new()); } #[tokio::test] @@ -323,10 +302,7 @@ async fn outer_middleware_still_see_whole_url() { .route("/", get(handler)) .route("/foo", get(handler)) .route("/foo/bar", get(handler)) - .nest( - "/one", - Router::new().route("/two", get(handler)).into_service(), - ) + .nest("/one", Router::new().route("/two", get(handler))) .fallback(handler) .layer(tower::layer::layer_fn(SetUriExtension)); @@ -344,13 +320,10 @@ async fn outer_middleware_still_see_whole_url() { #[tokio::test] async fn nest_at_capture() { - let api_routes = Router::new() - .route( - "/:b", - get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }), - ) - .into_service() - .boxed_clone(); + let api_routes = Router::new().route( + "/:b", + get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }), + ); let app = Router::new().nest("/:a", api_routes); @@ -363,7 +336,7 @@ async fn nest_at_capture() { #[tokio::test] async fn nest_with_and_without_trailing() { - let app = Router::new().nest("/foo", get(|| async {})); + let app = Router::new().nest_service("/foo", get(|| async {})); let client = TestClient::new(app); @@ -380,10 +353,7 @@ async fn nest_with_and_without_trailing() { #[tokio::test] async fn doesnt_call_outer_fallback() { let app = Router::new() - .nest( - "/foo", - Router::new().route("/", get(|| async {})).into_service(), - ) + .nest("/foo", Router::new().route("/", get(|| async {}))) .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); let client = TestClient::new(app); @@ -401,9 +371,7 @@ async fn doesnt_call_outer_fallback() { async fn nesting_with_root_inner_router() { let app = Router::new().nest( "/foo", - Router::new() - .route("/", get(|| async { "inner route" })) - .into_service(), + Router::new().route("/", get(|| async { "inner route" })), ); let client = TestClient::new(app); @@ -426,8 +394,7 @@ async fn fallback_on_inner() { "/foo", Router::new() .route("/", get(|| async {})) - .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }) - .into_service(), + .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), ) .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); @@ -451,7 +418,7 @@ macro_rules! nested_route_test { #[tokio::test] async fn $name() { let inner = Router::new().route($route_path, get(|| async {})); - let app = Router::new().nest($nested_path, inner.into_service()); + let app = Router::new().nest($nested_path, inner); let client = TestClient::new(app); let res = client.get($expected_path).send().await; let status = res.status(); @@ -486,7 +453,7 @@ async fn nesting_with_different_state() { "/foo", get(|State(state): State<&'static str>| async move { state }), ) - .nest("/nested", inner.into_service()) + .nest("/nested", inner) .route( "/bar", get(|State(state): State<&'static str>| async move { state }), @@ -503,3 +470,18 @@ async fn nesting_with_different_state() { let res = client.get("/bar").send().await; assert_eq!(res.text().await, "outer"); } + +#[tokio::test] +async fn inherit_state_via_nest() { + let foo = Router::inherit_state().route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::with_state("state").nest("/test", foo); + let client = TestClient::new(app); + + let res = client.get("/test/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "state"); +} diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 1d33b364..b9580102 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -58,7 +58,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes(shared_state).into_service()) + .nest("/admin", admin_routes(shared_state)) // Add middleware to all routes .layer( ServiceBuilder::new() diff --git a/examples/stream-to-file/src/main.rs b/examples/stream-to-file/src/main.rs index bd1405b7..018d3d2f 100644 --- a/examples/stream-to-file/src/main.rs +++ b/examples/stream-to-file/src/main.rs @@ -131,7 +131,7 @@ where // to prevent directory traversal attacks we ensure the path consists of exactly one normal // component fn path_is_valid(path: &str) -> bool { - let path = std::path::Path::new(&*path); + let path = std::path::Path::new(path); let mut components = path.components().peekable(); if let Some(first) = components.peek() {