1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-03-26 00:27:01 +01:00

Allow Routers to inherit state ()

* Rename Fallback::Custom to Fallback::Service

* Allow Routers to inherit state

* Rename Router::{nest => nest_service} and add new nest method for Routers

* Fix lints

* Add basic tests for state inheritance

* Changelog
This commit is contained in:
Jonas Platte 2022-09-25 13:56:23 +02:00 committed by GitHub
parent 2077d50021
commit 4847d681b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 453 additions and 225 deletions
axum-extra/src/routing
axum
examples
key-value-store/src
stream-to-file/src

View file

@ -149,7 +149,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
where
F: Clone + Send + 'static,
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
@ -161,7 +161,7 @@ where
.handle_error(spa.handle_error.clone());
Router::new()
.nest(&spa.paths.assets_path, assets_service)
.nest_service(&spa.paths.assets_path, assets_service)
.fallback_service(
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
)

View file

@ -22,6 +22,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`map_request_with_state_arc` for transforming the request with an async
function ([#1408])
- **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400])
- **changed:** `Router` no longer implements `Service`, call `.into_service()`
on it to obtain a `RouterService` that does
- **added:** Add `Router::inherit_state`, which creates a `Router` with an
arbitrary state type without actually supplying the state; such a `Router`
can't be turned into a service directly (`.into_service()` will panic), but
can be nested or merged into a `Router` with the same state type
- **changed:** `Router::nest` now only accepts `Router`s, the general-purpose
`Service` nesting method has been renamed to `nest_service`
[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387

View file

@ -16,10 +16,10 @@ let user_routes = Router::new().route("/:id", get(|| async {}));
let team_routes = Router::new().route("/", post(|| async {}));
let api_routes = Router::new()
.nest("/users", user_routes.into_service())
.nest("/teams", team_routes.into_service());
.nest("/users", user_routes)
.nest("/teams", team_routes);
let app = Router::new().nest("/api", api_routes.into_service());
let app = Router::new().nest("/api", api_routes);
// Our app now accepts
// - GET /api/users/:id
@ -58,7 +58,7 @@ async fn users_get(Path(params): Path<HashMap<String, String>>) {
let users_api = Router::new().route("/users/:id", get(users_get));
let app = Router::new().nest("/:version/api", users_api.into_service());
let app = Router::new().nest("/:version/api", users_api);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
@ -82,7 +82,7 @@ let app = Router::new()
.route("/foo/*rest", get(|uri: Uri| async {
// `uri` will contain `/foo`
}))
.nest("/bar", nested_router.into_service());
.nest("/bar", nested_router);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
@ -100,10 +100,10 @@ async fn fallback() -> (StatusCode, &'static str) {
(StatusCode::NOT_FOUND, "Not Found")
}
let api_routes = Router::new().nest("/users", get(|| async {}));
let api_routes = Router::new().nest_service("/users", get(|| async {}));
let app = Router::new()
.nest("/api", api_routes.into_service())
.nest("/api", api_routes)
.fallback(fallback);
# let _: Router = app;
```
@ -130,12 +130,12 @@ async fn api_fallback() -> (StatusCode, Json<Value>) {
}
let api_routes = Router::new()
.nest("/users", get(|| async {}))
.nest_service("/users", get(|| async {}))
// add dedicated fallback for requests starting with `/api`
.fallback(api_fallback);
let app = Router::new()
.nest("/api", api_routes.into_service())
.nest("/api", api_routes)
.fallback(fallback);
# let _: Router = app;
```

View file

@ -148,7 +148,7 @@ mod tests {
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
)
.nest("/api", api.into_service())
.nest("/api", api)
.nest(
"/public",
Router::new()
@ -156,10 +156,9 @@ mod tests {
// have to set the middleware here since otherwise the
// matched path is just `/public/*` since we're nesting
// this router
.layer(layer_fn(SetMatchedPathExtension))
.into_service(),
.layer(layer_fn(SetMatchedPathExtension)),
)
.nest("/foo", handler.into_service())
.nest_service("/foo", handler.into_service())
.layer(layer_fn(SetMatchedPathExtension));
let client = TestClient::new(app);
@ -198,12 +197,10 @@ mod tests {
async fn nested_opaque_routers_append_to_matched_path() {
let app = Router::new().nest(
"/:a",
Router::new()
.route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
)
.into_service(),
Router::new().route(
"/:b",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
),
);
let client = TestClient::new(app);

View file

@ -38,7 +38,7 @@ use sync_wrapper::SyncWrapper;
/// }),
/// );
///
/// let app = Router::new().nest("/api", api_routes.into_service());
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
@ -75,7 +75,7 @@ use sync_wrapper::SyncWrapper;
/// }),
/// );
///
/// let app = Router::new().nest("/api", api_routes.into_service());
/// let app = Router::new().nest("/api", api_routes);
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };

122
axum/src/handler/boxed.rs Normal file
View file

@ -0,0 +1,122 @@
use std::{convert::Infallible, sync::Arc};
use super::Handler;
use crate::routing::Route;
pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B, E>>);
impl<S, B> BoxedHandler<S, B>
where
S: Send + Sync + 'static,
B: Send + 'static,
{
pub(crate) fn new<H, T>(handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state_arc(handler, state)),
}))
}
}
impl<S, B, E> BoxedHandler<S, B, E> {
pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedHandler<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
BoxedHandler(Box::new(Map {
handler: self.0,
layer: Box::new(f),
}))
}
pub(crate) fn into_route(self, state: Arc<S>) -> Route<B, E> {
self.0.into_route(state)
}
}
impl<S, B, E> Clone for BoxedHandler<S, B, E> {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}
trait ErasedHandler<S, B, E = Infallible>: Send {
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>;
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B, E>;
}
struct MakeErasedHandler<H, S, B> {
handler: H,
into_route: fn(H, Arc<S>) -> Route<B>,
}
impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B>
where
H: Clone + Send + 'static,
S: 'static,
B: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B>> {
Box::new(self.clone())
}
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B> {
(self.into_route)(self.handler, state)
}
}
impl<H: Clone, S, B> Clone for MakeErasedHandler<H, S, B> {
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
into_route: self.into_route,
}
}
}
struct Map<S, B, E, B2, E2> {
handler: Box<dyn ErasedHandler<S, B, E>>,
layer: Box<dyn LayerFn<B, E, B2, E2>>,
}
impl<S, B, E, B2, E2> ErasedHandler<S, B2, E2> for Map<S, B, E, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
B2: 'static,
E2: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B2, E2>> {
Box::new(Self {
handler: self.handler.clone_box(),
layer: self.layer.clone_box(),
})
}
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B2, E2> {
(self.layer)(self.handler.into_route(state))
}
}
trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>;
}
impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F
where
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
{
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> {
Box::new(self.clone())
}
}

View file

@ -47,12 +47,15 @@ use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
mod boxed;
pub mod future;
mod into_service;
mod into_service_state_in_extension;
mod with_state;
pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension;
pub(crate) use self::{
boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension,
};
pub use self::{into_service::IntoService, with_state::WithState};
/// Trait for async functions that can be used to handle requests.

View file

@ -15,7 +15,6 @@ use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
@ -521,9 +520,8 @@ pub struct MethodRouter<S = (), B = Body, E = Infallible> {
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Fallback<B, E>,
fallback: Fallback<S, B, E>,
allow_header: AllowHeader,
_marker: PhantomData<fn() -> S>,
}
#[derive(Clone)]
@ -720,7 +718,6 @@ where
trace: None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
_marker: PhantomData,
}
}
@ -741,7 +738,12 @@ where
}
}
pub(crate) fn downcast_state<S2>(self) -> MethodRouter<S2, B, E> {
pub(crate) fn map_state<S2>(self, state: &Arc<S>) -> MethodRouter<S2, B, E>
where
E: 'static,
S: 'static,
S2: 'static,
{
MethodRouter {
get: self.get,
head: self.head,
@ -751,12 +753,31 @@ where
post: self.post,
put: self.put,
trace: self.trace,
fallback: self.fallback,
fallback: self.fallback.map_state(state),
allow_header: self.allow_header,
_marker: PhantomData,
}
}
pub(crate) fn downcast_state<S2>(self) -> Option<MethodRouter<S2, B, E>>
where
E: 'static,
S: 'static,
S2: 'static,
{
Some(MethodRouter {
get: self.get,
head: self.head,
delete: self.delete,
options: self.options,
patch: self.patch,
post: self.post,
put: self.put,
trace: self.trace,
fallback: self.fallback.downcast_state()?,
allow_header: self.allow_header,
})
}
/// Chain an additional service that will accept requests matching the given
/// `MethodFilter`.
///
@ -808,7 +829,7 @@ where
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
self.fallback = Fallback::Custom(Route::new(svc));
self.fallback = Fallback::Service(Route::new(svc));
self
}
@ -818,36 +839,40 @@ where
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
self.fallback = Fallback::Custom(Route::new(svc));
self.fallback = Fallback::Service(Route::new(svc));
self
}
#[doc = include_str!("../docs/method_routing/layer.md")]
pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
pub fn layer<L, NewReqBody: 'static, NewError: 'static>(
self,
layer: L,
) -> MethodRouter<S, NewReqBody, NewError>
where
L: Layer<Route<B, E>>,
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
E: 'static,
S: 'static,
{
let layer_fn = |svc| {
let layer_fn = move |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
MethodRouter {
get: self.get.map(layer_fn),
head: self.head.map(layer_fn),
delete: self.delete.map(layer_fn),
options: self.options.map(layer_fn),
patch: self.patch.map(layer_fn),
post: self.post.map(layer_fn),
put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn),
get: self.get.map(layer_fn.clone()),
head: self.head.map(layer_fn.clone()),
delete: self.delete.map(layer_fn.clone()),
options: self.options.map(layer_fn.clone()),
patch: self.patch.map(layer_fn.clone()),
post: self.post.map(layer_fn.clone()),
put: self.put.map(layer_fn.clone()),
trace: self.trace.map(layer_fn.clone()),
fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
_marker: self._marker,
}
}
@ -952,13 +977,14 @@ where
/// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible>
where
F: Clone + Send + 'static,
F: Clone + Send + Sync + 'static,
HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send,
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
T: 'static,
E: 'static,
B: 'static,
S: 'static,
{
self.layer(HandleErrorLayer::new(f))
}
@ -1136,7 +1162,6 @@ impl<S, B, E> Clone for MethodRouter<S, B, E> {
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
_marker: self._marker,
}
}
}
@ -1211,7 +1236,7 @@ where
impl<S, B, E> Service<Request<B>> for WithState<S, B, E>
where
B: HttpBody,
B: HttpBody + Send,
S: Send + Sync + 'static,
{
type Response = Response;
@ -1257,7 +1282,6 @@ where
trace,
fallback,
allow_header,
_marker: _,
},
} = self;
@ -1276,8 +1300,14 @@ where
let future = match fallback {
Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::BoxedHandler(fallback) => RouteFuture::from_future(
fallback
.clone()
.into_route(Arc::clone(state))
.oneshot_inner(req),
),
};
match allow_header {

View file

@ -4,14 +4,20 @@ use self::not_found::NotFound;
use crate::{
body::{Body, HttpBody},
extract::connect_info::IntoMakeServiceWithConnectInfo,
handler::Handler,
handler::{BoxedHandler, Handler},
util::try_downcast,
Extension,
};
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc};
use std::{
any::{type_name, TypeId},
collections::HashMap,
convert::Infallible,
fmt,
sync::Arc,
};
use tower::{util::MapResponseLayer, ServiceBuilder};
use tower_layer::Layer;
use tower_service::Service;
@ -59,16 +65,16 @@ impl RouteId {
/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
state: Arc<S>,
state: Option<Arc<S>>,
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<B>,
fallback: Fallback<S, B>,
}
impl<S, B> Clone for Router<S, B> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
state: self.state.clone(),
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
@ -162,7 +168,18 @@ where
/// [`State`]: crate::extract::State
pub fn with_state_arc(state: Arc<S>) -> Self {
Self {
state,
state: Some(state),
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
}
}
/// Create a new `Router` that inherits its state from another `Router` that it is merged into
/// or nested under.
pub fn inherit_state() -> Self {
Self {
state: None,
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
@ -253,7 +270,29 @@ where
#[doc = include_str!("../docs/routing/nest.md")]
#[track_caller]
pub fn nest<T>(mut self, mut path: &str, svc: T) -> Self
pub fn nest<S2>(self, path: &str, mut router: Router<S2, B>) -> Self
where
S2: Send + Sync + 'static,
{
if router.state.is_none() {
let s = self.state.clone();
router.state = match try_downcast::<Option<Arc<S2>>, Option<Arc<S>>>(s) {
Ok(state) => state,
Err(_) => panic!(
"can't nest a `Router` that wants to inherit state of type `{}` \
into a `Router` with a state type of `{}`",
type_name::<S2>(),
type_name::<S>(),
),
};
}
self.nest_service(path, router.into_service())
}
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
#[track_caller]
pub fn nest_service<T>(mut self, mut path: &str, svc: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
@ -305,20 +344,55 @@ where
fallback,
} = other.into();
let cast_method_router_closure_slot;
let (fallback, cast_method_router) = match state {
// other has its state set
Some(state) => {
let fallback = fallback.map_state(&state);
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| {
r.layer(Extension(Arc::clone(&state))).map_state(&state)
};
let cast_method_router = &cast_method_router_closure_slot
as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>;
(fallback, cast_method_router)
}
// other wants to inherit its state
None => {
if TypeId::of::<S>() != TypeId::of::<S2>() {
panic!(
"can't merge a `Router` that wants to inherit state of type `{}` \
into a `Router` with a state type of `{}`",
type_name::<S2>(),
type_name::<S>(),
);
}
// With the branch above not taken, we know we can cast S2 to S
let fallback = fallback.downcast_state::<S>().unwrap();
fn cast_method_router<S, S2, B>(r: MethodRouter<S2, B>) -> MethodRouter<S, B>
where
B: Send + 'static,
S: 'static,
S2: 'static,
{
r.downcast_state().unwrap()
}
(fallback, &cast_method_router as _)
}
};
for (id, route) in routes {
let path = node
.route_id_to_path
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");
self = match route {
Endpoint::MethodRouter(method_router) => self.route(
path,
method_router
// this will set the state for each route
// such we don't override the inner state later in `MethodRouterWithState`
.layer(Extension(Arc::clone(&state)))
.downcast_state(),
),
Endpoint::MethodRouter(method_router) => {
self.route(path, cast_method_router(method_router))
}
Endpoint::Route(route) => self.route_service(path, route),
};
}
@ -332,9 +406,9 @@ where
}
#[doc = include_str!("../docs/routing/layer.md")]
pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody>
where
L: Layer<Route<B>>,
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
@ -352,7 +426,7 @@ where
.map(|(id, route)| {
let route = match route {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(&layer))
Endpoint::MethodRouter(method_router.layer(layer.clone()))
}
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
};
@ -360,7 +434,7 @@ where
})
.collect();
let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc)));
let fallback = self.fallback.map(move |svc| Route::new(layer.layer(svc)));
Router {
state: self.state,
@ -374,7 +448,7 @@ where
#[track_caller]
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>>,
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
@ -399,7 +473,7 @@ where
.map(|(id, route)| {
let route = match route {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(&layer))
Endpoint::MethodRouter(method_router.layer(layer.clone()))
}
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
};
@ -416,13 +490,13 @@ where
}
#[doc = include_str!("../docs/routing/fallback.md")]
pub fn fallback<H, T>(self, handler: H) -> Self
pub fn fallback<H, T>(mut self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
let state = Arc::clone(&self.state);
self.fallback_service(handler.with_state_arc(state))
self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
self
}
/// Add a fallback [`Service`] to the router.
@ -434,7 +508,7 @@ where
T::Response: IntoResponse,
T::Future: Send + 'static,
{
self.fallback = Fallback::Custom(Route::new(svc));
self.fallback = Fallback::Service(Route::new(svc));
self
}
@ -478,11 +552,6 @@ where
) -> IntoMakeServiceWithConnectInfo<RouterService<B>, C> {
IntoMakeServiceWithConnectInfo::new(self.into_service())
}
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
}
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
@ -526,12 +595,39 @@ impl fmt::Debug for Node {
}
}
enum Fallback<B, E = Infallible> {
enum Fallback<S, B, E = Infallible> {
Default(Route<B, E>),
Custom(Route<B, E>),
Service(Route<B, E>),
BoxedHandler(BoxedHandler<S, B, E>),
}
impl<B, E> Fallback<B, E> {
impl<S, B, E> Fallback<S, B, E> {
fn map_state<S2>(self, state: &Arc<S>) -> Fallback<S2, B, E> {
match self {
Self::Default(route) => Fallback::Default(route),
Self::Service(route) => Fallback::Service(route),
Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state.clone())),
}
}
fn downcast_state<S2>(self) -> Option<Fallback<S2, B, E>>
where
S: 'static,
B: 'static,
E: 'static,
S2: 'static,
{
match self {
Self::Default(route) => Some(Fallback::Default(route)),
Self::Service(route) => Some(Fallback::Service(route)),
Self::BoxedHandler(handler) => {
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.map(Fallback::BoxedHandler)
.ok()
}
}
}
fn merge(self, other: Self) -> Option<Self> {
match (self, other) {
(Self::Default(_), pick @ Self::Default(_)) => Some(pick),
@ -541,32 +637,40 @@ impl<B, E> Fallback<B, E> {
}
}
impl<B, E> Clone for Fallback<B, E> {
impl<S, B, E> Clone for Fallback<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::Default(inner) => Self::Default(inner.clone()),
Self::Custom(inner) => Self::Custom(inner.clone()),
Self::Service(inner) => Self::Service(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<B, E> fmt::Debug for Fallback<B, E> {
impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(),
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
impl<B, E> Fallback<B, E> {
fn map<F, B2, E2>(self, f: F) -> Fallback<B2, E2>
impl<S, B, E> Fallback<S, B, E> {
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
F: FnOnce(Route<B, E>) -> Route<B2, E2>,
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::Default(inner) => Fallback::Default(f(inner)),
Self::Custom(inner) => Fallback::Custom(f(inner)),
Self::Service(inner) => Fallback::Service(f(inner)),
Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)),
}
}
}

View file

@ -29,7 +29,7 @@ use tower_service::Service;
pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
impl<B, E> Route<B, E> {
pub(super) fn new<T>(svc: T) -> Self
pub(crate) fn new<T>(svc: T) -> Self
where
T: Service<Request<B>, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,

View file

@ -30,17 +30,22 @@ impl<B> RouterService<B>
where
B: HttpBody + Send + 'static,
{
#[track_caller]
pub(super) fn new<S>(router: Router<S, B>) -> Self
where
S: Send + Sync + 'static,
{
let state = router
.state
.expect("Can't turn a `Router` that wants to inherit state into a service");
let routes = router
.routes
.into_iter()
.map(|(route_id, endpoint)| {
let route = match endpoint {
Endpoint::MethodRouter(method_router) => {
Route::new(method_router.with_state_arc(Arc::clone(&router.state)))
Route::new(method_router.with_state_arc(Arc::clone(&state)))
}
Endpoint::Route(route) => route,
};
@ -54,7 +59,8 @@ where
node: router.node,
fallback: match router.fallback {
Fallback::Default(route) => route,
Fallback::Custom(route) => route,
Fallback::Service(route) => route,
Fallback::BoxedHandler(handler) => handler.into_route(state),
},
}
}

View file

@ -18,10 +18,7 @@ async fn basic() {
#[tokio::test]
async fn nest() {
let app = Router::new()
.nest(
"/foo",
Router::new().route("/bar", get(|| async {})).into_service(),
)
.nest("/foo", Router::new().route("/bar", get(|| async {})))
.fallback(|| async { "fallback" });
let client = TestClient::new(app);

View file

@ -82,7 +82,7 @@ async fn nested_or() {
assert_eq!(client.get("/bar").send().await.text().await, "bar");
assert_eq!(client.get("/baz").send().await.text().await, "baz");
let client = TestClient::new(Router::new().nest("/foo", bar_or_baz.into_service()));
let client = TestClient::new(Router::new().nest("/foo", bar_or_baz));
assert_eq!(client.get("/foo/bar").send().await.text().await, "bar");
assert_eq!(client.get("/foo/baz").send().await.text().await, "baz");
}
@ -145,10 +145,7 @@ async fn layer_and_handle_error() {
#[tokio::test]
async fn nesting() {
let one = Router::new().route("/foo", get(|| async {}));
let two = Router::new().nest(
"/bar",
Router::new().route("/baz", get(|| async {})).into_service(),
);
let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {})));
let app = one.merge(two);
let client = TestClient::new(app);
@ -232,12 +229,7 @@ async fn all_the_uris(
#[tokio::test]
async fn nesting_and_seeing_the_right_uri() {
let one = Router::new().nest(
"/foo/",
Router::new()
.route("/bar", get(all_the_uris))
.into_service(),
);
let one = Router::new().nest("/foo/", Router::new().route("/bar", get(all_the_uris)));
let two = Router::new().route("/foo", get(all_the_uris));
let client = TestClient::new(one.merge(two));
@ -269,14 +261,7 @@ async fn nesting_and_seeing_the_right_uri() {
async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
let one = Router::new().nest(
"/foo/",
Router::new()
.nest(
"/bar",
Router::new()
.route("/baz", get(all_the_uris))
.into_service(),
)
.into_service(),
Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))),
);
let two = Router::new().route("/foo", get(all_the_uris));
@ -309,21 +294,9 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
let one = Router::new().nest(
"/one",
Router::new()
.nest(
"/bar",
Router::new()
.route("/baz", get(all_the_uris))
.into_service(),
)
.into_service(),
);
let two = Router::new().nest(
"/two",
Router::new()
.route("/qux", get(all_the_uris))
.into_service(),
Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))),
);
let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris)));
let three = Router::new().route("/three", get(all_the_uris));
let client = TestClient::new(one.merge(two).merge(three));
@ -366,14 +339,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
let one = Router::new().nest(
"/one",
Router::new()
.nest(
"/foo",
Router::new()
.route("/bar", get(all_the_uris))
.into_service(),
)
.into_service(),
Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))),
);
let two = Router::new().route("/two/foo", get(all_the_uris));
@ -500,3 +466,18 @@ async fn merging_routes_different_paths_different_states() {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "bar state");
}
#[tokio::test]
async fn inherit_state_via_merge() {
let foo = Router::inherit_state().route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
);
let app = Router::with_state("state").merge(foo);
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "state");
}

View file

@ -19,9 +19,7 @@ use std::{
task::{Context, Poll},
time::Duration,
};
use tower::{
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
};
use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder};
use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer};
use tower_service::Service;

View file

@ -37,7 +37,7 @@ async fn nesting_apps() {
let app = Router::new()
.route("/", get(|| async { "hi" }))
.nest("/:version/api", api_routes.into_service());
.nest("/:version/api", api_routes);
let client = TestClient::new(app);
@ -61,7 +61,7 @@ async fn nesting_apps() {
#[tokio::test]
async fn wrong_method_nest() {
let nested_app = Router::new().route("/", get(|| async {}));
let app = Router::new().nest("/", nested_app.into_service());
let app = Router::new().nest("/", nested_app);
let client = TestClient::new(app);
@ -78,7 +78,7 @@ async fn wrong_method_nest() {
#[tokio::test]
async fn nesting_router_at_root() {
let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() }));
let app = Router::new().nest("/", nested.into_service());
let app = Router::new().nest("/", nested);
let client = TestClient::new(app);
@ -96,7 +96,7 @@ async fn nesting_router_at_root() {
#[tokio::test]
async fn nesting_router_at_empty_path() {
let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() }));
let app = Router::new().nest("", nested.into_service());
let app = Router::new().nest("", nested);
let client = TestClient::new(app);
@ -113,7 +113,7 @@ async fn nesting_router_at_empty_path() {
#[tokio::test]
async fn nesting_handler_at_root() {
let app = Router::new().nest("/", get(|uri: Uri| async move { uri.to_string() }));
let app = Router::new().nest_service("/", get(|uri: Uri| async move { uri.to_string() }));
let client = TestClient::new(app);
@ -134,18 +134,15 @@ async fn nesting_handler_at_root() {
async fn nested_url_extractor() {
let app = Router::new().nest(
"/foo",
Router::new()
.nest(
"/bar",
Router::new()
.route("/baz", get(|uri: Uri| async move { uri.to_string() }))
.route(
"/qux",
get(|req: Request<Body>| async move { req.uri().to_string() }),
)
.into_service(),
)
.into_service(),
Router::new().nest(
"/bar",
Router::new()
.route("/baz", get(|uri: Uri| async move { uri.to_string() }))
.route(
"/qux",
get(|req: Request<Body>| async move { req.uri().to_string() }),
),
),
);
let client = TestClient::new(app);
@ -163,17 +160,13 @@ async fn nested_url_extractor() {
async fn nested_url_original_extractor() {
let app = Router::new().nest(
"/foo",
Router::new()
.nest(
"/bar",
Router::new()
.route(
"/baz",
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
)
.into_service(),
)
.into_service(),
Router::new().nest(
"/bar",
Router::new().route(
"/baz",
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
),
),
);
let client = TestClient::new(app);
@ -187,20 +180,16 @@ async fn nested_url_original_extractor() {
async fn nested_service_sees_stripped_uri() {
let app = Router::new().nest(
"/foo",
Router::new()
.nest(
"/bar",
Router::new()
.route_service(
"/baz",
service_fn(|req: Request<Body>| async move {
let body = boxed(Body::from(req.uri().to_string()));
Ok::<_, Infallible>(Response::new(body))
}),
)
.into_service(),
)
.into_service(),
Router::new().nest(
"/bar",
Router::new().route_service(
"/baz",
service_fn(|req: Request<Body>| async move {
let body = boxed(Body::from(req.uri().to_string()));
Ok::<_, Infallible>(Response::new(body))
}),
),
),
);
let client = TestClient::new(app);
@ -212,7 +201,7 @@ async fn nested_service_sees_stripped_uri() {
#[tokio::test]
async fn nest_static_file_server() {
let app = Router::new().nest(
let app = Router::new().nest_service(
"/static",
get_service(ServeDir::new(".")).handle_error(|error| async move {
(
@ -235,8 +224,7 @@ async fn nested_multiple_routes() {
"/api",
Router::new()
.route("/users", get(|| async { "users" }))
.route("/teams", get(|| async { "teams" }))
.into_service(),
.route("/teams", get(|| async { "teams" })),
)
.route("/", get(|| async { "root" }));
@ -251,12 +239,7 @@ async fn nested_multiple_routes() {
#[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"]
fn nested_at_root_with_other_routes() {
let _: Router = Router::new()
.nest(
"/",
Router::new()
.route("/users", get(|| async {}))
.into_service(),
)
.nest("/", Router::new().route("/users", get(|| async {})))
.route("/", get(|| async {}));
}
@ -265,15 +248,11 @@ async fn multiple_top_level_nests() {
let app = Router::new()
.nest(
"/one",
Router::new()
.route("/route", get(|| async { "one" }))
.into_service(),
Router::new().route("/route", get(|| async { "one" })),
)
.nest(
"/two",
Router::new()
.route("/route", get(|| async { "two" }))
.into_service(),
Router::new().route("/route", get(|| async { "two" })),
);
let client = TestClient::new(app);
@ -285,7 +264,7 @@ async fn multiple_top_level_nests() {
#[tokio::test]
#[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")]
async fn nest_cannot_contain_wildcards() {
Router::<_, Body>::new().nest("/one/*rest", Router::new().into_service());
Router::<_, Body>::new().nest("/one/*rest", Router::new());
}
#[tokio::test]
@ -323,10 +302,7 @@ async fn outer_middleware_still_see_whole_url() {
.route("/", get(handler))
.route("/foo", get(handler))
.route("/foo/bar", get(handler))
.nest(
"/one",
Router::new().route("/two", get(handler)).into_service(),
)
.nest("/one", Router::new().route("/two", get(handler)))
.fallback(handler)
.layer(tower::layer::layer_fn(SetUriExtension));
@ -344,13 +320,10 @@ async fn outer_middleware_still_see_whole_url() {
#[tokio::test]
async fn nest_at_capture() {
let api_routes = Router::new()
.route(
"/:b",
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
)
.into_service()
.boxed_clone();
let api_routes = Router::new().route(
"/:b",
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
);
let app = Router::new().nest("/:a", api_routes);
@ -363,7 +336,7 @@ async fn nest_at_capture() {
#[tokio::test]
async fn nest_with_and_without_trailing() {
let app = Router::new().nest("/foo", get(|| async {}));
let app = Router::new().nest_service("/foo", get(|| async {}));
let client = TestClient::new(app);
@ -380,10 +353,7 @@ async fn nest_with_and_without_trailing() {
#[tokio::test]
async fn doesnt_call_outer_fallback() {
let app = Router::new()
.nest(
"/foo",
Router::new().route("/", get(|| async {})).into_service(),
)
.nest("/foo", Router::new().route("/", get(|| async {})))
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
let client = TestClient::new(app);
@ -401,9 +371,7 @@ async fn doesnt_call_outer_fallback() {
async fn nesting_with_root_inner_router() {
let app = Router::new().nest(
"/foo",
Router::new()
.route("/", get(|| async { "inner route" }))
.into_service(),
Router::new().route("/", get(|| async { "inner route" })),
);
let client = TestClient::new(app);
@ -426,8 +394,7 @@ async fn fallback_on_inner() {
"/foo",
Router::new()
.route("/", get(|| async {}))
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") })
.into_service(),
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
)
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
@ -451,7 +418,7 @@ macro_rules! nested_route_test {
#[tokio::test]
async fn $name() {
let inner = Router::new().route($route_path, get(|| async {}));
let app = Router::new().nest($nested_path, inner.into_service());
let app = Router::new().nest($nested_path, inner);
let client = TestClient::new(app);
let res = client.get($expected_path).send().await;
let status = res.status();
@ -486,7 +453,7 @@ async fn nesting_with_different_state() {
"/foo",
get(|State(state): State<&'static str>| async move { state }),
)
.nest("/nested", inner.into_service())
.nest("/nested", inner)
.route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
@ -503,3 +470,18 @@ async fn nesting_with_different_state() {
let res = client.get("/bar").send().await;
assert_eq!(res.text().await, "outer");
}
#[tokio::test]
async fn inherit_state_via_nest() {
let foo = Router::inherit_state().route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
);
let app = Router::with_state("state").nest("/test", foo);
let client = TestClient::new(app);
let res = client.get("/test/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "state");
}

View file

@ -58,7 +58,7 @@ async fn main() {
)
.route("/keys", get(list_keys))
// Nest our admin routes under `/admin`
.nest("/admin", admin_routes(shared_state).into_service())
.nest("/admin", admin_routes(shared_state))
// Add middleware to all routes
.layer(
ServiceBuilder::new()

View file

@ -131,7 +131,7 @@ where
// to prevent directory traversal attacks we ensure the path consists of exactly one normal
// component
fn path_is_valid(path: &str) -> bool {
let path = std::path::Path::new(&*path);
let path = std::path::Path::new(path);
let mut components = path.components().peekable();
if let Some(first) = components.peek() {