mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-26 00:27:01 +01:00
Allow Routers to inherit state (#1368)
* Rename Fallback::Custom to Fallback::Service * Allow Routers to inherit state * Rename Router::{nest => nest_service} and add new nest method for Routers * Fix lints * Add basic tests for state inheritance * Changelog
This commit is contained in:
parent
2077d50021
commit
4847d681b1
17 changed files with 453 additions and 225 deletions
|
@ -149,7 +149,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
|
|||
|
||||
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
|
||||
where
|
||||
F: Clone + Send + 'static,
|
||||
F: Clone + Send + Sync + 'static,
|
||||
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
|
||||
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
|
||||
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
|
||||
|
@ -161,7 +161,7 @@ where
|
|||
.handle_error(spa.handle_error.clone());
|
||||
|
||||
Router::new()
|
||||
.nest(&spa.paths.assets_path, assets_service)
|
||||
.nest_service(&spa.paths.assets_path, assets_service)
|
||||
.fallback_service(
|
||||
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
|
||||
)
|
||||
|
|
|
@ -22,6 +22,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
`map_request_with_state_arc` for transforming the request with an async
|
||||
function ([#1408])
|
||||
- **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400])
|
||||
- **changed:** `Router` no longer implements `Service`, call `.into_service()`
|
||||
on it to obtain a `RouterService` that does
|
||||
- **added:** Add `Router::inherit_state`, which creates a `Router` with an
|
||||
arbitrary state type without actually supplying the state; such a `Router`
|
||||
can't be turned into a service directly (`.into_service()` will panic), but
|
||||
can be nested or merged into a `Router` with the same state type
|
||||
- **changed:** `Router::nest` now only accepts `Router`s, the general-purpose
|
||||
`Service` nesting method has been renamed to `nest_service`
|
||||
|
||||
[#1371]: https://github.com/tokio-rs/axum/pull/1371
|
||||
[#1387]: https://github.com/tokio-rs/axum/pull/1387
|
||||
|
|
|
@ -16,10 +16,10 @@ let user_routes = Router::new().route("/:id", get(|| async {}));
|
|||
let team_routes = Router::new().route("/", post(|| async {}));
|
||||
|
||||
let api_routes = Router::new()
|
||||
.nest("/users", user_routes.into_service())
|
||||
.nest("/teams", team_routes.into_service());
|
||||
.nest("/users", user_routes)
|
||||
.nest("/teams", team_routes);
|
||||
|
||||
let app = Router::new().nest("/api", api_routes.into_service());
|
||||
let app = Router::new().nest("/api", api_routes);
|
||||
|
||||
// Our app now accepts
|
||||
// - GET /api/users/:id
|
||||
|
@ -58,7 +58,7 @@ async fn users_get(Path(params): Path<HashMap<String, String>>) {
|
|||
|
||||
let users_api = Router::new().route("/users/:id", get(users_get));
|
||||
|
||||
let app = Router::new().nest("/:version/api", users_api.into_service());
|
||||
let app = Router::new().nest("/:version/api", users_api);
|
||||
# async {
|
||||
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
# };
|
||||
|
@ -82,7 +82,7 @@ let app = Router::new()
|
|||
.route("/foo/*rest", get(|uri: Uri| async {
|
||||
// `uri` will contain `/foo`
|
||||
}))
|
||||
.nest("/bar", nested_router.into_service());
|
||||
.nest("/bar", nested_router);
|
||||
# async {
|
||||
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
# };
|
||||
|
@ -100,10 +100,10 @@ async fn fallback() -> (StatusCode, &'static str) {
|
|||
(StatusCode::NOT_FOUND, "Not Found")
|
||||
}
|
||||
|
||||
let api_routes = Router::new().nest("/users", get(|| async {}));
|
||||
let api_routes = Router::new().nest_service("/users", get(|| async {}));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", api_routes.into_service())
|
||||
.nest("/api", api_routes)
|
||||
.fallback(fallback);
|
||||
# let _: Router = app;
|
||||
```
|
||||
|
@ -130,12 +130,12 @@ async fn api_fallback() -> (StatusCode, Json<Value>) {
|
|||
}
|
||||
|
||||
let api_routes = Router::new()
|
||||
.nest("/users", get(|| async {}))
|
||||
.nest_service("/users", get(|| async {}))
|
||||
// add dedicated fallback for requests starting with `/api`
|
||||
.fallback(api_fallback);
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", api_routes.into_service())
|
||||
.nest("/api", api_routes)
|
||||
.fallback(fallback);
|
||||
# let _: Router = app;
|
||||
```
|
||||
|
|
|
@ -148,7 +148,7 @@ mod tests {
|
|||
"/:key",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
|
||||
)
|
||||
.nest("/api", api.into_service())
|
||||
.nest("/api", api)
|
||||
.nest(
|
||||
"/public",
|
||||
Router::new()
|
||||
|
@ -156,10 +156,9 @@ mod tests {
|
|||
// have to set the middleware here since otherwise the
|
||||
// matched path is just `/public/*` since we're nesting
|
||||
// this router
|
||||
.layer(layer_fn(SetMatchedPathExtension))
|
||||
.into_service(),
|
||||
.layer(layer_fn(SetMatchedPathExtension)),
|
||||
)
|
||||
.nest("/foo", handler.into_service())
|
||||
.nest_service("/foo", handler.into_service())
|
||||
.layer(layer_fn(SetMatchedPathExtension));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -198,12 +197,10 @@ mod tests {
|
|||
async fn nested_opaque_routers_append_to_matched_path() {
|
||||
let app = Router::new().nest(
|
||||
"/:a",
|
||||
Router::new()
|
||||
.route(
|
||||
"/:b",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().route(
|
||||
"/:b",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
|
||||
),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
|
|
@ -38,7 +38,7 @@ use sync_wrapper::SyncWrapper;
|
|||
/// }),
|
||||
/// );
|
||||
///
|
||||
/// let app = Router::new().nest("/api", api_routes.into_service());
|
||||
/// let app = Router::new().nest("/api", api_routes);
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
|
@ -75,7 +75,7 @@ use sync_wrapper::SyncWrapper;
|
|||
/// }),
|
||||
/// );
|
||||
///
|
||||
/// let app = Router::new().nest("/api", api_routes.into_service());
|
||||
/// let app = Router::new().nest("/api", api_routes);
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
|
|
122
axum/src/handler/boxed.rs
Normal file
122
axum/src/handler/boxed.rs
Normal file
|
@ -0,0 +1,122 @@
|
|||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use super::Handler;
|
||||
use crate::routing::Route;
|
||||
|
||||
pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B, E>>);
|
||||
|
||||
impl<S, B> BoxedHandler<S, B>
|
||||
where
|
||||
S: Send + Sync + 'static,
|
||||
B: Send + 'static,
|
||||
{
|
||||
pub(crate) fn new<H, T>(handler: H) -> Self
|
||||
where
|
||||
H: Handler<T, S, B>,
|
||||
T: 'static,
|
||||
{
|
||||
Self(Box::new(MakeErasedHandler {
|
||||
handler,
|
||||
into_route: |handler, state| Route::new(Handler::with_state_arc(handler, state)),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, E> BoxedHandler<S, B, E> {
|
||||
pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedHandler<S, B2, E2>
|
||||
where
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
|
||||
B2: 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
BoxedHandler(Box::new(Map {
|
||||
handler: self.0,
|
||||
layer: Box::new(f),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn into_route(self, state: Arc<S>) -> Route<B, E> {
|
||||
self.0.into_route(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, E> Clone for BoxedHandler<S, B, E> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone_box())
|
||||
}
|
||||
}
|
||||
|
||||
trait ErasedHandler<S, B, E = Infallible>: Send {
|
||||
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>;
|
||||
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B, E>;
|
||||
}
|
||||
|
||||
struct MakeErasedHandler<H, S, B> {
|
||||
handler: H,
|
||||
into_route: fn(H, Arc<S>) -> Route<B>,
|
||||
}
|
||||
|
||||
impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B>
|
||||
where
|
||||
H: Clone + Send + 'static,
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B> {
|
||||
(self.into_route)(self.handler, state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<H: Clone, S, B> Clone for MakeErasedHandler<H, S, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
handler: self.handler.clone(),
|
||||
into_route: self.into_route,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Map<S, B, E, B2, E2> {
|
||||
handler: Box<dyn ErasedHandler<S, B, E>>,
|
||||
layer: Box<dyn LayerFn<B, E, B2, E2>>,
|
||||
}
|
||||
|
||||
impl<S, B, E, B2, E2> ErasedHandler<S, B2, E2> for Map<S, B, E, B2, E2>
|
||||
where
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
B2: 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B2, E2>> {
|
||||
Box::new(Self {
|
||||
handler: self.handler.clone_box(),
|
||||
layer: self.layer.clone_box(),
|
||||
})
|
||||
}
|
||||
|
||||
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B2, E2> {
|
||||
(self.layer)(self.handler.into_route(state))
|
||||
}
|
||||
}
|
||||
|
||||
trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {
|
||||
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>;
|
||||
}
|
||||
|
||||
impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F
|
||||
where
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
|
@ -47,12 +47,15 @@ use tower::ServiceExt;
|
|||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
mod boxed;
|
||||
pub mod future;
|
||||
mod into_service;
|
||||
mod into_service_state_in_extension;
|
||||
mod with_state;
|
||||
|
||||
pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension;
|
||||
pub(crate) use self::{
|
||||
boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension,
|
||||
};
|
||||
pub use self::{into_service::IntoService, with_state::WithState};
|
||||
|
||||
/// Trait for async functions that can be used to handle requests.
|
||||
|
|
|
@ -15,7 +15,6 @@ use bytes::BytesMut;
|
|||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
@ -521,9 +520,8 @@ pub struct MethodRouter<S = (), B = Body, E = Infallible> {
|
|||
post: Option<Route<B, E>>,
|
||||
put: Option<Route<B, E>>,
|
||||
trace: Option<Route<B, E>>,
|
||||
fallback: Fallback<B, E>,
|
||||
fallback: Fallback<S, B, E>,
|
||||
allow_header: AllowHeader,
|
||||
_marker: PhantomData<fn() -> S>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -720,7 +718,6 @@ where
|
|||
trace: None,
|
||||
allow_header: AllowHeader::None,
|
||||
fallback: Fallback::Default(fallback),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -741,7 +738,12 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn downcast_state<S2>(self) -> MethodRouter<S2, B, E> {
|
||||
pub(crate) fn map_state<S2>(self, state: &Arc<S>) -> MethodRouter<S2, B, E>
|
||||
where
|
||||
E: 'static,
|
||||
S: 'static,
|
||||
S2: 'static,
|
||||
{
|
||||
MethodRouter {
|
||||
get: self.get,
|
||||
head: self.head,
|
||||
|
@ -751,12 +753,31 @@ where
|
|||
post: self.post,
|
||||
put: self.put,
|
||||
trace: self.trace,
|
||||
fallback: self.fallback,
|
||||
fallback: self.fallback.map_state(state),
|
||||
allow_header: self.allow_header,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn downcast_state<S2>(self) -> Option<MethodRouter<S2, B, E>>
|
||||
where
|
||||
E: 'static,
|
||||
S: 'static,
|
||||
S2: 'static,
|
||||
{
|
||||
Some(MethodRouter {
|
||||
get: self.get,
|
||||
head: self.head,
|
||||
delete: self.delete,
|
||||
options: self.options,
|
||||
patch: self.patch,
|
||||
post: self.post,
|
||||
put: self.put,
|
||||
trace: self.trace,
|
||||
fallback: self.fallback.downcast_state()?,
|
||||
allow_header: self.allow_header,
|
||||
})
|
||||
}
|
||||
|
||||
/// Chain an additional service that will accept requests matching the given
|
||||
/// `MethodFilter`.
|
||||
///
|
||||
|
@ -808,7 +829,7 @@ where
|
|||
T::Response: IntoResponse + 'static,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.fallback = Fallback::Custom(Route::new(svc));
|
||||
self.fallback = Fallback::Service(Route::new(svc));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -818,36 +839,40 @@ where
|
|||
T::Response: IntoResponse + 'static,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.fallback = Fallback::Custom(Route::new(svc));
|
||||
self.fallback = Fallback::Service(Route::new(svc));
|
||||
self
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/method_routing/layer.md")]
|
||||
pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
|
||||
pub fn layer<L, NewReqBody: 'static, NewError: 'static>(
|
||||
self,
|
||||
layer: L,
|
||||
) -> MethodRouter<S, NewReqBody, NewError>
|
||||
where
|
||||
L: Layer<Route<B, E>>,
|
||||
L: Layer<Route<B, E>> + Clone + Send + 'static,
|
||||
L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
|
||||
E: 'static,
|
||||
S: 'static,
|
||||
{
|
||||
let layer_fn = |svc| {
|
||||
let layer_fn = move |svc| {
|
||||
let svc = layer.layer(svc);
|
||||
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
|
||||
Route::new(svc)
|
||||
};
|
||||
|
||||
MethodRouter {
|
||||
get: self.get.map(layer_fn),
|
||||
head: self.head.map(layer_fn),
|
||||
delete: self.delete.map(layer_fn),
|
||||
options: self.options.map(layer_fn),
|
||||
patch: self.patch.map(layer_fn),
|
||||
post: self.post.map(layer_fn),
|
||||
put: self.put.map(layer_fn),
|
||||
trace: self.trace.map(layer_fn),
|
||||
get: self.get.map(layer_fn.clone()),
|
||||
head: self.head.map(layer_fn.clone()),
|
||||
delete: self.delete.map(layer_fn.clone()),
|
||||
options: self.options.map(layer_fn.clone()),
|
||||
patch: self.patch.map(layer_fn.clone()),
|
||||
post: self.post.map(layer_fn.clone()),
|
||||
put: self.put.map(layer_fn.clone()),
|
||||
trace: self.trace.map(layer_fn.clone()),
|
||||
fallback: self.fallback.map(layer_fn),
|
||||
allow_header: self.allow_header,
|
||||
_marker: self._marker,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -952,13 +977,14 @@ where
|
|||
/// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
|
||||
pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible>
|
||||
where
|
||||
F: Clone + Send + 'static,
|
||||
F: Clone + Send + Sync + 'static,
|
||||
HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>,
|
||||
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send,
|
||||
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
|
||||
T: 'static,
|
||||
E: 'static,
|
||||
B: 'static,
|
||||
S: 'static,
|
||||
{
|
||||
self.layer(HandleErrorLayer::new(f))
|
||||
}
|
||||
|
@ -1136,7 +1162,6 @@ impl<S, B, E> Clone for MethodRouter<S, B, E> {
|
|||
trace: self.trace.clone(),
|
||||
fallback: self.fallback.clone(),
|
||||
allow_header: self.allow_header.clone(),
|
||||
_marker: self._marker,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1211,7 +1236,7 @@ where
|
|||
|
||||
impl<S, B, E> Service<Request<B>> for WithState<S, B, E>
|
||||
where
|
||||
B: HttpBody,
|
||||
B: HttpBody + Send,
|
||||
S: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response;
|
||||
|
@ -1257,7 +1282,6 @@ where
|
|||
trace,
|
||||
fallback,
|
||||
allow_header,
|
||||
_marker: _,
|
||||
},
|
||||
} = self;
|
||||
|
||||
|
@ -1276,8 +1300,14 @@ where
|
|||
let future = match fallback {
|
||||
Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
|
||||
.strip_body(method == Method::HEAD),
|
||||
Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
|
||||
Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
|
||||
.strip_body(method == Method::HEAD),
|
||||
Fallback::BoxedHandler(fallback) => RouteFuture::from_future(
|
||||
fallback
|
||||
.clone()
|
||||
.into_route(Arc::clone(state))
|
||||
.oneshot_inner(req),
|
||||
),
|
||||
};
|
||||
|
||||
match allow_header {
|
||||
|
|
|
@ -4,14 +4,20 @@ use self::not_found::NotFound;
|
|||
use crate::{
|
||||
body::{Body, HttpBody},
|
||||
extract::connect_info::IntoMakeServiceWithConnectInfo,
|
||||
handler::Handler,
|
||||
handler::{BoxedHandler, Handler},
|
||||
util::try_downcast,
|
||||
Extension,
|
||||
};
|
||||
use axum_core::response::IntoResponse;
|
||||
use http::Request;
|
||||
use matchit::MatchError;
|
||||
use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc};
|
||||
use std::{
|
||||
any::{type_name, TypeId},
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
sync::Arc,
|
||||
};
|
||||
use tower::{util::MapResponseLayer, ServiceBuilder};
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
@ -59,16 +65,16 @@ impl RouteId {
|
|||
|
||||
/// The router type for composing handlers and services.
|
||||
pub struct Router<S = (), B = Body> {
|
||||
state: Arc<S>,
|
||||
state: Option<Arc<S>>,
|
||||
routes: HashMap<RouteId, Endpoint<S, B>>,
|
||||
node: Arc<Node>,
|
||||
fallback: Fallback<B>,
|
||||
fallback: Fallback<S, B>,
|
||||
}
|
||||
|
||||
impl<S, B> Clone for Router<S, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
state: self.state.clone(),
|
||||
routes: self.routes.clone(),
|
||||
node: Arc::clone(&self.node),
|
||||
fallback: self.fallback.clone(),
|
||||
|
@ -162,7 +168,18 @@ where
|
|||
/// [`State`]: crate::extract::State
|
||||
pub fn with_state_arc(state: Arc<S>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
state: Some(state),
|
||||
routes: Default::default(),
|
||||
node: Default::default(),
|
||||
fallback: Fallback::Default(Route::new(NotFound)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `Router` that inherits its state from another `Router` that it is merged into
|
||||
/// or nested under.
|
||||
pub fn inherit_state() -> Self {
|
||||
Self {
|
||||
state: None,
|
||||
routes: Default::default(),
|
||||
node: Default::default(),
|
||||
fallback: Fallback::Default(Route::new(NotFound)),
|
||||
|
@ -253,7 +270,29 @@ where
|
|||
|
||||
#[doc = include_str!("../docs/routing/nest.md")]
|
||||
#[track_caller]
|
||||
pub fn nest<T>(mut self, mut path: &str, svc: T) -> Self
|
||||
pub fn nest<S2>(self, path: &str, mut router: Router<S2, B>) -> Self
|
||||
where
|
||||
S2: Send + Sync + 'static,
|
||||
{
|
||||
if router.state.is_none() {
|
||||
let s = self.state.clone();
|
||||
router.state = match try_downcast::<Option<Arc<S2>>, Option<Arc<S>>>(s) {
|
||||
Ok(state) => state,
|
||||
Err(_) => panic!(
|
||||
"can't nest a `Router` that wants to inherit state of type `{}` \
|
||||
into a `Router` with a state type of `{}`",
|
||||
type_name::<S2>(),
|
||||
type_name::<S>(),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
self.nest_service(path, router.into_service())
|
||||
}
|
||||
|
||||
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
|
||||
#[track_caller]
|
||||
pub fn nest_service<T>(mut self, mut path: &str, svc: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
|
@ -305,20 +344,55 @@ where
|
|||
fallback,
|
||||
} = other.into();
|
||||
|
||||
let cast_method_router_closure_slot;
|
||||
let (fallback, cast_method_router) = match state {
|
||||
// other has its state set
|
||||
Some(state) => {
|
||||
let fallback = fallback.map_state(&state);
|
||||
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| {
|
||||
r.layer(Extension(Arc::clone(&state))).map_state(&state)
|
||||
};
|
||||
let cast_method_router = &cast_method_router_closure_slot
|
||||
as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>;
|
||||
|
||||
(fallback, cast_method_router)
|
||||
}
|
||||
// other wants to inherit its state
|
||||
None => {
|
||||
if TypeId::of::<S>() != TypeId::of::<S2>() {
|
||||
panic!(
|
||||
"can't merge a `Router` that wants to inherit state of type `{}` \
|
||||
into a `Router` with a state type of `{}`",
|
||||
type_name::<S2>(),
|
||||
type_name::<S>(),
|
||||
);
|
||||
}
|
||||
|
||||
// With the branch above not taken, we know we can cast S2 to S
|
||||
let fallback = fallback.downcast_state::<S>().unwrap();
|
||||
|
||||
fn cast_method_router<S, S2, B>(r: MethodRouter<S2, B>) -> MethodRouter<S, B>
|
||||
where
|
||||
B: Send + 'static,
|
||||
S: 'static,
|
||||
S2: 'static,
|
||||
{
|
||||
r.downcast_state().unwrap()
|
||||
}
|
||||
|
||||
(fallback, &cast_method_router as _)
|
||||
}
|
||||
};
|
||||
|
||||
for (id, route) in routes {
|
||||
let path = node
|
||||
.route_id_to_path
|
||||
.get(&id)
|
||||
.expect("no path for route id. This is a bug in axum. Please file an issue");
|
||||
self = match route {
|
||||
Endpoint::MethodRouter(method_router) => self.route(
|
||||
path,
|
||||
method_router
|
||||
// this will set the state for each route
|
||||
// such we don't override the inner state later in `MethodRouterWithState`
|
||||
.layer(Extension(Arc::clone(&state)))
|
||||
.downcast_state(),
|
||||
),
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
self.route(path, cast_method_router(method_router))
|
||||
}
|
||||
Endpoint::Route(route) => self.route_service(path, route),
|
||||
};
|
||||
}
|
||||
|
@ -332,9 +406,9 @@ where
|
|||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/layer.md")]
|
||||
pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
|
||||
pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody>
|
||||
where
|
||||
L: Layer<Route<B>>,
|
||||
L: Layer<Route<B>> + Clone + Send + 'static,
|
||||
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
|
||||
|
@ -352,7 +426,7 @@ where
|
|||
.map(|(id, route)| {
|
||||
let route = match route {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Endpoint::MethodRouter(method_router.layer(&layer))
|
||||
Endpoint::MethodRouter(method_router.layer(layer.clone()))
|
||||
}
|
||||
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
|
||||
};
|
||||
|
@ -360,7 +434,7 @@ where
|
|||
})
|
||||
.collect();
|
||||
|
||||
let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc)));
|
||||
let fallback = self.fallback.map(move |svc| Route::new(layer.layer(svc)));
|
||||
|
||||
Router {
|
||||
state: self.state,
|
||||
|
@ -374,7 +448,7 @@ where
|
|||
#[track_caller]
|
||||
pub fn route_layer<L>(self, layer: L) -> Self
|
||||
where
|
||||
L: Layer<Route<B>>,
|
||||
L: Layer<Route<B>> + Clone + Send + 'static,
|
||||
L::Service: Service<Request<B>> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
|
||||
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
|
||||
|
@ -399,7 +473,7 @@ where
|
|||
.map(|(id, route)| {
|
||||
let route = match route {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Endpoint::MethodRouter(method_router.layer(&layer))
|
||||
Endpoint::MethodRouter(method_router.layer(layer.clone()))
|
||||
}
|
||||
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
|
||||
};
|
||||
|
@ -416,13 +490,13 @@ where
|
|||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||
pub fn fallback<H, T>(self, handler: H) -> Self
|
||||
pub fn fallback<H, T>(mut self, handler: H) -> Self
|
||||
where
|
||||
H: Handler<T, S, B>,
|
||||
T: 'static,
|
||||
{
|
||||
let state = Arc::clone(&self.state);
|
||||
self.fallback_service(handler.with_state_arc(state))
|
||||
self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a fallback [`Service`] to the router.
|
||||
|
@ -434,7 +508,7 @@ where
|
|||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.fallback = Fallback::Custom(Route::new(svc));
|
||||
self.fallback = Fallback::Service(Route::new(svc));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -478,11 +552,6 @@ where
|
|||
) -> IntoMakeServiceWithConnectInfo<RouterService<B>, C> {
|
||||
IntoMakeServiceWithConnectInfo::new(self.into_service())
|
||||
}
|
||||
|
||||
/// Get a reference to the state.
|
||||
pub fn state(&self) -> &S {
|
||||
&self.state
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
|
||||
|
@ -526,12 +595,39 @@ impl fmt::Debug for Node {
|
|||
}
|
||||
}
|
||||
|
||||
enum Fallback<B, E = Infallible> {
|
||||
enum Fallback<S, B, E = Infallible> {
|
||||
Default(Route<B, E>),
|
||||
Custom(Route<B, E>),
|
||||
Service(Route<B, E>),
|
||||
BoxedHandler(BoxedHandler<S, B, E>),
|
||||
}
|
||||
|
||||
impl<B, E> Fallback<B, E> {
|
||||
impl<S, B, E> Fallback<S, B, E> {
|
||||
fn map_state<S2>(self, state: &Arc<S>) -> Fallback<S2, B, E> {
|
||||
match self {
|
||||
Self::Default(route) => Fallback::Default(route),
|
||||
Self::Service(route) => Fallback::Service(route),
|
||||
Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn downcast_state<S2>(self) -> Option<Fallback<S2, B, E>>
|
||||
where
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
S2: 'static,
|
||||
{
|
||||
match self {
|
||||
Self::Default(route) => Some(Fallback::Default(route)),
|
||||
Self::Service(route) => Some(Fallback::Service(route)),
|
||||
Self::BoxedHandler(handler) => {
|
||||
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
|
||||
.map(Fallback::BoxedHandler)
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge(self, other: Self) -> Option<Self> {
|
||||
match (self, other) {
|
||||
(Self::Default(_), pick @ Self::Default(_)) => Some(pick),
|
||||
|
@ -541,32 +637,40 @@ impl<B, E> Fallback<B, E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B, E> Clone for Fallback<B, E> {
|
||||
impl<S, B, E> Clone for Fallback<S, B, E> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Default(inner) => Self::Default(inner.clone()),
|
||||
Self::Custom(inner) => Self::Custom(inner.clone()),
|
||||
Self::Service(inner) => Self::Service(inner.clone()),
|
||||
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> fmt::Debug for Fallback<B, E> {
|
||||
impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
|
||||
Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(),
|
||||
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
|
||||
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Fallback<B, E> {
|
||||
fn map<F, B2, E2>(self, f: F) -> Fallback<B2, E2>
|
||||
impl<S, B, E> Fallback<S, B, E> {
|
||||
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
|
||||
where
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2>,
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
|
||||
B2: 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
match self {
|
||||
Self::Default(inner) => Fallback::Default(f(inner)),
|
||||
Self::Custom(inner) => Fallback::Custom(f(inner)),
|
||||
Self::Service(inner) => Fallback::Service(f(inner)),
|
||||
Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ use tower_service::Service;
|
|||
pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
|
||||
|
||||
impl<B, E> Route<B, E> {
|
||||
pub(super) fn new<T>(svc: T) -> Self
|
||||
pub(crate) fn new<T>(svc: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = E> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse + 'static,
|
||||
|
|
|
@ -30,17 +30,22 @@ impl<B> RouterService<B>
|
|||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
{
|
||||
#[track_caller]
|
||||
pub(super) fn new<S>(router: Router<S, B>) -> Self
|
||||
where
|
||||
S: Send + Sync + 'static,
|
||||
{
|
||||
let state = router
|
||||
.state
|
||||
.expect("Can't turn a `Router` that wants to inherit state into a service");
|
||||
|
||||
let routes = router
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(route_id, endpoint)| {
|
||||
let route = match endpoint {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Route::new(method_router.with_state_arc(Arc::clone(&router.state)))
|
||||
Route::new(method_router.with_state_arc(Arc::clone(&state)))
|
||||
}
|
||||
Endpoint::Route(route) => route,
|
||||
};
|
||||
|
@ -54,7 +59,8 @@ where
|
|||
node: router.node,
|
||||
fallback: match router.fallback {
|
||||
Fallback::Default(route) => route,
|
||||
Fallback::Custom(route) => route,
|
||||
Fallback::Service(route) => route,
|
||||
Fallback::BoxedHandler(handler) => handler.into_route(state),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,10 +18,7 @@ async fn basic() {
|
|||
#[tokio::test]
|
||||
async fn nest() {
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/foo",
|
||||
Router::new().route("/bar", get(|| async {})).into_service(),
|
||||
)
|
||||
.nest("/foo", Router::new().route("/bar", get(|| async {})))
|
||||
.fallback(|| async { "fallback" });
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
|
|
@ -82,7 +82,7 @@ async fn nested_or() {
|
|||
assert_eq!(client.get("/bar").send().await.text().await, "bar");
|
||||
assert_eq!(client.get("/baz").send().await.text().await, "baz");
|
||||
|
||||
let client = TestClient::new(Router::new().nest("/foo", bar_or_baz.into_service()));
|
||||
let client = TestClient::new(Router::new().nest("/foo", bar_or_baz));
|
||||
assert_eq!(client.get("/foo/bar").send().await.text().await, "bar");
|
||||
assert_eq!(client.get("/foo/baz").send().await.text().await, "baz");
|
||||
}
|
||||
|
@ -145,10 +145,7 @@ async fn layer_and_handle_error() {
|
|||
#[tokio::test]
|
||||
async fn nesting() {
|
||||
let one = Router::new().route("/foo", get(|| async {}));
|
||||
let two = Router::new().nest(
|
||||
"/bar",
|
||||
Router::new().route("/baz", get(|| async {})).into_service(),
|
||||
);
|
||||
let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {})));
|
||||
let app = one.merge(two);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -232,12 +229,7 @@ async fn all_the_uris(
|
|||
|
||||
#[tokio::test]
|
||||
async fn nesting_and_seeing_the_right_uri() {
|
||||
let one = Router::new().nest(
|
||||
"/foo/",
|
||||
Router::new()
|
||||
.route("/bar", get(all_the_uris))
|
||||
.into_service(),
|
||||
);
|
||||
let one = Router::new().nest("/foo/", Router::new().route("/bar", get(all_the_uris)));
|
||||
let two = Router::new().route("/foo", get(all_the_uris));
|
||||
|
||||
let client = TestClient::new(one.merge(two));
|
||||
|
@ -269,14 +261,7 @@ async fn nesting_and_seeing_the_right_uri() {
|
|||
async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
|
||||
let one = Router::new().nest(
|
||||
"/foo/",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route("/baz", get(all_the_uris))
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))),
|
||||
);
|
||||
let two = Router::new().route("/foo", get(all_the_uris));
|
||||
|
||||
|
@ -309,21 +294,9 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
|
|||
async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
|
||||
let one = Router::new().nest(
|
||||
"/one",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route("/baz", get(all_the_uris))
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
);
|
||||
let two = Router::new().nest(
|
||||
"/two",
|
||||
Router::new()
|
||||
.route("/qux", get(all_the_uris))
|
||||
.into_service(),
|
||||
Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))),
|
||||
);
|
||||
let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris)));
|
||||
let three = Router::new().route("/three", get(all_the_uris));
|
||||
|
||||
let client = TestClient::new(one.merge(two).merge(three));
|
||||
|
@ -366,14 +339,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
|
|||
async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
|
||||
let one = Router::new().nest(
|
||||
"/one",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.route("/bar", get(all_the_uris))
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))),
|
||||
);
|
||||
let two = Router::new().route("/two/foo", get(all_the_uris));
|
||||
|
||||
|
@ -500,3 +466,18 @@ async fn merging_routes_different_paths_different_states() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "bar state");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn inherit_state_via_merge() {
|
||||
let foo = Router::inherit_state().route(
|
||||
"/foo",
|
||||
get(|State(state): State<&'static str>| async move { state }),
|
||||
);
|
||||
|
||||
let app = Router::with_state("state").merge(foo);
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "state");
|
||||
}
|
||||
|
|
|
@ -19,9 +19,7 @@ use std::{
|
|||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tower::{
|
||||
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
|
||||
};
|
||||
use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder};
|
||||
use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer};
|
||||
use tower_service::Service;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ async fn nesting_apps() {
|
|||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "hi" }))
|
||||
.nest("/:version/api", api_routes.into_service());
|
||||
.nest("/:version/api", api_routes);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -61,7 +61,7 @@ async fn nesting_apps() {
|
|||
#[tokio::test]
|
||||
async fn wrong_method_nest() {
|
||||
let nested_app = Router::new().route("/", get(|| async {}));
|
||||
let app = Router::new().nest("/", nested_app.into_service());
|
||||
let app = Router::new().nest("/", nested_app);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -78,7 +78,7 @@ async fn wrong_method_nest() {
|
|||
#[tokio::test]
|
||||
async fn nesting_router_at_root() {
|
||||
let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() }));
|
||||
let app = Router::new().nest("/", nested.into_service());
|
||||
let app = Router::new().nest("/", nested);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -96,7 +96,7 @@ async fn nesting_router_at_root() {
|
|||
#[tokio::test]
|
||||
async fn nesting_router_at_empty_path() {
|
||||
let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() }));
|
||||
let app = Router::new().nest("", nested.into_service());
|
||||
let app = Router::new().nest("", nested);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -113,7 +113,7 @@ async fn nesting_router_at_empty_path() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn nesting_handler_at_root() {
|
||||
let app = Router::new().nest("/", get(|uri: Uri| async move { uri.to_string() }));
|
||||
let app = Router::new().nest_service("/", get(|uri: Uri| async move { uri.to_string() }));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -134,18 +134,15 @@ async fn nesting_handler_at_root() {
|
|||
async fn nested_url_extractor() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route("/baz", get(|uri: Uri| async move { uri.to_string() }))
|
||||
.route(
|
||||
"/qux",
|
||||
get(|req: Request<Body>| async move { req.uri().to_string() }),
|
||||
)
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route("/baz", get(|uri: Uri| async move { uri.to_string() }))
|
||||
.route(
|
||||
"/qux",
|
||||
get(|req: Request<Body>| async move { req.uri().to_string() }),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -163,17 +160,13 @@ async fn nested_url_extractor() {
|
|||
async fn nested_url_original_extractor() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route(
|
||||
"/baz",
|
||||
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
|
||||
)
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().nest(
|
||||
"/bar",
|
||||
Router::new().route(
|
||||
"/baz",
|
||||
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -187,20 +180,16 @@ async fn nested_url_original_extractor() {
|
|||
async fn nested_service_sees_stripped_uri() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.nest(
|
||||
"/bar",
|
||||
Router::new()
|
||||
.route_service(
|
||||
"/baz",
|
||||
service_fn(|req: Request<Body>| async move {
|
||||
let body = boxed(Body::from(req.uri().to_string()));
|
||||
Ok::<_, Infallible>(Response::new(body))
|
||||
}),
|
||||
)
|
||||
.into_service(),
|
||||
)
|
||||
.into_service(),
|
||||
Router::new().nest(
|
||||
"/bar",
|
||||
Router::new().route_service(
|
||||
"/baz",
|
||||
service_fn(|req: Request<Body>| async move {
|
||||
let body = boxed(Body::from(req.uri().to_string()));
|
||||
Ok::<_, Infallible>(Response::new(body))
|
||||
}),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -212,7 +201,7 @@ async fn nested_service_sees_stripped_uri() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn nest_static_file_server() {
|
||||
let app = Router::new().nest(
|
||||
let app = Router::new().nest_service(
|
||||
"/static",
|
||||
get_service(ServeDir::new(".")).handle_error(|error| async move {
|
||||
(
|
||||
|
@ -235,8 +224,7 @@ async fn nested_multiple_routes() {
|
|||
"/api",
|
||||
Router::new()
|
||||
.route("/users", get(|| async { "users" }))
|
||||
.route("/teams", get(|| async { "teams" }))
|
||||
.into_service(),
|
||||
.route("/teams", get(|| async { "teams" })),
|
||||
)
|
||||
.route("/", get(|| async { "root" }));
|
||||
|
||||
|
@ -251,12 +239,7 @@ async fn nested_multiple_routes() {
|
|||
#[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"]
|
||||
fn nested_at_root_with_other_routes() {
|
||||
let _: Router = Router::new()
|
||||
.nest(
|
||||
"/",
|
||||
Router::new()
|
||||
.route("/users", get(|| async {}))
|
||||
.into_service(),
|
||||
)
|
||||
.nest("/", Router::new().route("/users", get(|| async {})))
|
||||
.route("/", get(|| async {}));
|
||||
}
|
||||
|
||||
|
@ -265,15 +248,11 @@ async fn multiple_top_level_nests() {
|
|||
let app = Router::new()
|
||||
.nest(
|
||||
"/one",
|
||||
Router::new()
|
||||
.route("/route", get(|| async { "one" }))
|
||||
.into_service(),
|
||||
Router::new().route("/route", get(|| async { "one" })),
|
||||
)
|
||||
.nest(
|
||||
"/two",
|
||||
Router::new()
|
||||
.route("/route", get(|| async { "two" }))
|
||||
.into_service(),
|
||||
Router::new().route("/route", get(|| async { "two" })),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -285,7 +264,7 @@ async fn multiple_top_level_nests() {
|
|||
#[tokio::test]
|
||||
#[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")]
|
||||
async fn nest_cannot_contain_wildcards() {
|
||||
Router::<_, Body>::new().nest("/one/*rest", Router::new().into_service());
|
||||
Router::<_, Body>::new().nest("/one/*rest", Router::new());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -323,10 +302,7 @@ async fn outer_middleware_still_see_whole_url() {
|
|||
.route("/", get(handler))
|
||||
.route("/foo", get(handler))
|
||||
.route("/foo/bar", get(handler))
|
||||
.nest(
|
||||
"/one",
|
||||
Router::new().route("/two", get(handler)).into_service(),
|
||||
)
|
||||
.nest("/one", Router::new().route("/two", get(handler)))
|
||||
.fallback(handler)
|
||||
.layer(tower::layer::layer_fn(SetUriExtension));
|
||||
|
||||
|
@ -344,13 +320,10 @@ async fn outer_middleware_still_see_whole_url() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn nest_at_capture() {
|
||||
let api_routes = Router::new()
|
||||
.route(
|
||||
"/:b",
|
||||
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
|
||||
)
|
||||
.into_service()
|
||||
.boxed_clone();
|
||||
let api_routes = Router::new().route(
|
||||
"/:b",
|
||||
get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }),
|
||||
);
|
||||
|
||||
let app = Router::new().nest("/:a", api_routes);
|
||||
|
||||
|
@ -363,7 +336,7 @@ async fn nest_at_capture() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn nest_with_and_without_trailing() {
|
||||
let app = Router::new().nest("/foo", get(|| async {}));
|
||||
let app = Router::new().nest_service("/foo", get(|| async {}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
@ -380,10 +353,7 @@ async fn nest_with_and_without_trailing() {
|
|||
#[tokio::test]
|
||||
async fn doesnt_call_outer_fallback() {
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/foo",
|
||||
Router::new().route("/", get(|| async {})).into_service(),
|
||||
)
|
||||
.nest("/foo", Router::new().route("/", get(|| async {})))
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -401,9 +371,7 @@ async fn doesnt_call_outer_fallback() {
|
|||
async fn nesting_with_root_inner_router() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.route("/", get(|| async { "inner route" }))
|
||||
.into_service(),
|
||||
Router::new().route("/", get(|| async { "inner route" })),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -426,8 +394,7 @@ async fn fallback_on_inner() {
|
|||
"/foo",
|
||||
Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") })
|
||||
.into_service(),
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
|
||||
)
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
|
||||
|
||||
|
@ -451,7 +418,7 @@ macro_rules! nested_route_test {
|
|||
#[tokio::test]
|
||||
async fn $name() {
|
||||
let inner = Router::new().route($route_path, get(|| async {}));
|
||||
let app = Router::new().nest($nested_path, inner.into_service());
|
||||
let app = Router::new().nest($nested_path, inner);
|
||||
let client = TestClient::new(app);
|
||||
let res = client.get($expected_path).send().await;
|
||||
let status = res.status();
|
||||
|
@ -486,7 +453,7 @@ async fn nesting_with_different_state() {
|
|||
"/foo",
|
||||
get(|State(state): State<&'static str>| async move { state }),
|
||||
)
|
||||
.nest("/nested", inner.into_service())
|
||||
.nest("/nested", inner)
|
||||
.route(
|
||||
"/bar",
|
||||
get(|State(state): State<&'static str>| async move { state }),
|
||||
|
@ -503,3 +470,18 @@ async fn nesting_with_different_state() {
|
|||
let res = client.get("/bar").send().await;
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn inherit_state_via_nest() {
|
||||
let foo = Router::inherit_state().route(
|
||||
"/foo",
|
||||
get(|State(state): State<&'static str>| async move { state }),
|
||||
);
|
||||
|
||||
let app = Router::with_state("state").nest("/test", foo);
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/test/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "state");
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ async fn main() {
|
|||
)
|
||||
.route("/keys", get(list_keys))
|
||||
// Nest our admin routes under `/admin`
|
||||
.nest("/admin", admin_routes(shared_state).into_service())
|
||||
.nest("/admin", admin_routes(shared_state))
|
||||
// Add middleware to all routes
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
|
|
|
@ -131,7 +131,7 @@ where
|
|||
// to prevent directory traversal attacks we ensure the path consists of exactly one normal
|
||||
// component
|
||||
fn path_is_valid(path: &str) -> bool {
|
||||
let path = std::path::Path::new(&*path);
|
||||
let path = std::path::Path::new(path);
|
||||
let mut components = path.components().peekable();
|
||||
|
||||
if let Some(first) = components.peek() {
|
||||
|
|
Loading…
Add table
Reference in a new issue