Change Router::nest to flatten the routes (#1711)

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2023-04-11 16:17:54 +02:00 committed by GitHub
parent 14d2b2dc87
commit 2c2cf361dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 702 additions and 399 deletions

View file

@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **fixed:** Fixed performance regression with `Router::nest` introduced in
0.6.0. `nest` now flattens the routes which performs better ([#1711])
- **fixed:** Extracting `MatchedPath` in nested handlers now gives the full
matched path, including the nested path ([#1711])
- **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922]) - **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922])
[#1711]: https://github.com/tokio-rs/axum/pull/1711
[#1922]: https://github.com/tokio-rs/axum/pull/1922 [#1922]: https://github.com/tokio-rs/axum/pull/1922
# 0.6.12 (22. March, 2023) # 0.6.12 (22. March, 2023)

View file

@ -28,25 +28,6 @@ where
into_route: |handler, state| Route::new(Handler::with_state(handler, state)), into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
})) }))
} }
pub(crate) fn from_router(router: Router<S, B>) -> Self
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
Self(Box::new(MakeErasedRouter {
router,
into_route: |router, state| Route::new(router.with_state(state)),
}))
}
pub(crate) fn call_with_state(
self,
request: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
self.0.call_with_state(request, state)
}
} }
impl<S, B, E> BoxedIntoRoute<S, B, E> { impl<S, B, E> BoxedIntoRoute<S, B, E> {

View file

@ -235,6 +235,26 @@ mod tests {
req req
} }
let app = Router::new()
.nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(extract_matched_path));
let client = TestClient::new(app);
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[crate::test]
async fn can_extract_nested_matched_path_in_middleware_using_nest() {
async fn extract_matched_path<B>(
matched_path: Option<MatchedPath>,
req: Request<B>,
) -> Request<B> {
assert_eq!(matched_path.unwrap().as_str(), "/:a/:b");
req
}
let app = Router::new() let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {}))) .nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(extract_matched_path)); .layer(map_request(extract_matched_path));
@ -253,7 +273,7 @@ mod tests {
} }
let app = Router::new() let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {}))) .nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(assert_no_matched_path)); .layer(map_request(assert_no_matched_path));
let client = TestClient::new(app); let client = TestClient::new(app);
@ -262,6 +282,23 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[tokio::test]
async fn can_extract_nested_matched_path_in_middleware_via_extension_using_nest() {
async fn assert_matched_path<B>(req: Request<B>) -> Request<B> {
assert!(req.extensions().get::<MatchedPath>().is_some());
req
}
let app = Router::new()
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
.layer(map_request(assert_matched_path));
let client = TestClient::new(app);
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[crate::test] #[crate::test]
async fn can_extract_nested_matched_path_in_middleware_on_nested_router() { async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> { async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {

View file

@ -1,6 +1,6 @@
//! Routing between [`Service`]s and handlers. //! Routing between [`Service`]s and handlers.
use self::{future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix}; use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
#[cfg(feature = "tokio")] #[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{ use crate::{
@ -11,12 +11,9 @@ use crate::{
}; };
use axum_core::response::{IntoResponse, Response}; use axum_core::response::{IntoResponse, Response};
use http::Request; use http::Request;
use matchit::MatchError;
use std::{ use std::{
collections::HashMap,
convert::Infallible, convert::Infallible,
fmt, fmt,
sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use sync_wrapper::SyncWrapper; use sync_wrapper::SyncWrapper;
@ -29,6 +26,7 @@ pub mod method_routing;
mod into_make_service; mod into_make_service;
mod method_filter; mod method_filter;
mod not_found; mod not_found;
pub(crate) mod path_router;
mod route; mod route;
mod strip_prefix; mod strip_prefix;
pub(crate) mod url_params; pub(crate) mod url_params;
@ -44,25 +42,32 @@ pub use self::method_routing::{
trace_service, MethodRouter, trace_service, MethodRouter,
}; };
macro_rules! panic_on_err {
($expr:expr) => {
match $expr {
Ok(x) => x,
Err(err) => panic!("{err}"),
}
};
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32); pub(crate) struct RouteId(u32);
/// The router type for composing handlers and services. /// The router type for composing handlers and services.
#[must_use] #[must_use]
pub struct Router<S = (), B = Body> { pub struct Router<S = (), B = Body> {
routes: HashMap<RouteId, Endpoint<S, B>>, path_router: PathRouter<S, B>,
node: Arc<Node>, fallback_router: PathRouter<S, B>,
fallback: Fallback<S, B>, default_fallback: bool,
prev_route_id: RouteId,
} }
impl<S, B> Clone for Router<S, B> { impl<S, B> Clone for Router<S, B> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
routes: self.routes.clone(), path_router: self.path_router.clone(),
node: Arc::clone(&self.node), fallback_router: self.fallback_router.clone(),
fallback: self.fallback.clone(), default_fallback: self.default_fallback,
prev_route_id: self.prev_route_id,
} }
} }
} }
@ -80,16 +85,16 @@ where
impl<S, B> fmt::Debug for Router<S, B> { impl<S, B> fmt::Debug for Router<S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router") f.debug_struct("Router")
.field("routes", &self.routes) .field("path_router", &self.path_router)
.field("node", &self.node) .field("fallback_router", &self.fallback_router)
.field("fallback", &self.fallback) .field("default_fallback", &self.default_fallback)
.field("prev_route_id", &self.prev_route_id)
.finish() .finish()
} }
} }
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
impl<S, B> Router<S, B> impl<S, B> Router<S, B>
where where
@ -101,57 +106,25 @@ where
/// Unless you add additional routes this will respond with `404 Not Found` to /// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests. /// all requests.
pub fn new() -> Self { pub fn new() -> Self {
Self { let mut this = Self {
routes: Default::default(), path_router: Default::default(),
node: Default::default(), fallback_router: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)), default_fallback: true,
prev_route_id: RouteId(0), };
} this = this.fallback_service(NotFound);
this.default_fallback = true;
this
} }
#[doc = include_str!("../docs/routing/route.md")] #[doc = include_str!("../docs/routing/route.md")]
#[track_caller] #[track_caller]
pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self { pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
#[track_caller] panic_on_err!(self.path_router.route(path, method_router));
fn validate_path(path: &str) {
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
}
validate_path(path);
let id = self.next_route_id();
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
.path_to_route_id
.get(path)
.and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
{
// if we're adding a new `MethodRouter` to a route that already has one just
// merge them. This makes `.route("/", get(_)).route("/", post(_))` work
let service = Endpoint::MethodRouter(
prev_method_router
.clone()
.merge_for_path(Some(path), method_router),
);
self.routes.insert(route_id, service);
return self;
} else {
Endpoint::MethodRouter(method_router)
};
self.set_node(path, id);
self.routes.insert(id, endpoint);
self self
} }
#[doc = include_str!("../docs/routing/route_service.md")] #[doc = include_str!("../docs/routing/route_service.md")]
pub fn route_service<T>(self, path: &str, service: T) -> Self pub fn route_service<T>(mut self, path: &str, service: T) -> Self
where where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse, T::Response: IntoResponse,
@ -164,104 +137,40 @@ where
Use `Router::nest` instead" Use `Router::nest` instead"
); );
} }
Err(svc) => svc, Err(service) => service,
}; };
self.route_endpoint(path, Endpoint::Route(Route::new(service))) panic_on_err!(self.path_router.route_service(path, service));
}
#[track_caller]
fn route_endpoint(mut self, path: &str, endpoint: Endpoint<S, B>) -> Self {
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
let id = self.next_route_id();
self.set_node(path, id);
self.routes.insert(id, endpoint);
self self
} }
#[track_caller]
fn set_node(&mut self, path: &str, id: RouteId) {
let mut node =
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
if let Err(err) = node.insert(path, id) {
panic!("Invalid route {path:?}: {err}");
}
self.node = Arc::new(node);
}
#[doc = include_str!("../docs/routing/nest.md")] #[doc = include_str!("../docs/routing/nest.md")]
#[track_caller] #[track_caller]
pub fn nest(self, path: &str, router: Router<S, B>) -> Self { pub fn nest(mut self, path: &str, router: Router<S, B>) -> Self {
self.nest_endpoint(path, RouterOrService::<_, _, NotFound>::Router(router)) let Router {
path_router,
fallback_router,
default_fallback,
} = router;
panic_on_err!(self.path_router.nest(path, path_router));
if !default_fallback {
panic_on_err!(self.fallback_router.nest(path, fallback_router));
}
self
} }
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
#[track_caller] #[track_caller]
pub fn nest_service<T>(self, path: &str, svc: T) -> Self pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
where where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse, T::Response: IntoResponse,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
self.nest_endpoint(path, RouterOrService::Service(svc)) panic_on_err!(self.path_router.nest_service(path, service));
}
#[track_caller]
fn nest_endpoint<T>(
mut self,
mut path: &str,
router_or_service: RouterOrService<S, B, T>,
) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
if path.is_empty() {
// nesting at `""` and `"/"` should mean the same thing
path = "/";
}
if path.contains('*') {
panic!("Invalid route: nested routes cannot contain wildcards (*)");
}
let prefix = path;
let path = if path.ends_with('/') {
format!("{path}*{NEST_TAIL_PARAM}")
} else {
format!("{path}/*{NEST_TAIL_PARAM}")
};
let endpoint = match router_or_service {
RouterOrService::Router(router) => {
let prefix = prefix.to_owned();
let boxed = BoxedIntoRoute::from_router(router)
.map(move |route| Route::new(StripPrefix::new(route, &prefix)));
Endpoint::NestedRouter(boxed)
}
RouterOrService::Service(svc) => {
Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)))
}
};
self = self.route_endpoint(&path, endpoint.clone());
// `/*rest` is not matched by `/` so we need to also register a router at the
// prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
// wouldn't match, which it should
self = self.route_endpoint(prefix, endpoint.clone());
if !prefix.ends_with('/') {
// same goes for `/foo/`, that should also match
self = self.route_endpoint(&format!("{prefix}/"), endpoint);
}
self self
} }
@ -272,30 +181,32 @@ where
R: Into<Router<S, B>>, R: Into<Router<S, B>>,
{ {
let Router { let Router {
routes, path_router,
node, fallback_router: other_fallback,
fallback, default_fallback,
prev_route_id: _,
} = other.into(); } = other.into();
for (id, route) in routes { panic_on_err!(self.path_router.merge(path_router));
let path = node
.route_id_to_path match (self.default_fallback, default_fallback) {
.get(&id) // both have the default fallback
.expect("no path for route id. This is a bug in axum. Please file an issue"); // use the one from other
self = match route { (true, true) => {
Endpoint::MethodRouter(method_router) => self.route(path, method_router), self.fallback_router = other_fallback;
Endpoint::Route(route) => self.route_service(path, route), }
Endpoint::NestedRouter(router) => { // self has default fallback, other has a custom fallback
self.route_endpoint(path, Endpoint::NestedRouter(router)) (true, false) => {
self.fallback_router = other_fallback;
self.default_fallback = false;
}
// self has a custom fallback, other has a default
// nothing to do
(false, true) => {}
// both have a custom fallback, not allowed
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
} }
}; };
}
self.fallback = self
.fallback
.merge(fallback)
.expect("Cannot merge two `Router`s that both have a fallback");
self self
} }
@ -310,22 +221,10 @@ where
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: HttpBody + 'static, NewReqBody: HttpBody + 'static,
{ {
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
let fallback = self.fallback.map(|route| route.layer(layer));
Router { Router {
routes, path_router: self.path_router.layer(layer.clone()),
node: self.node, fallback_router: self.fallback_router.layer(layer),
fallback, default_fallback: self.default_fallback,
prev_route_id: self.prev_route_id,
} }
} }
@ -339,79 +238,50 @@ where
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,
{ {
if self.routes.is_empty() {
panic!(
"Adding a route_layer before any routes is a no-op. \
Add the routes you want the layer to apply to first."
);
}
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
Router { Router {
routes, path_router: self.path_router.route_layer(layer),
node: self.node, fallback_router: self.fallback_router,
fallback: self.fallback, default_fallback: self.default_fallback,
prev_route_id: self.prev_route_id,
} }
} }
#[track_caller]
#[doc = include_str!("../docs/routing/fallback.md")] #[doc = include_str!("../docs/routing/fallback.md")]
pub fn fallback<H, T>(mut self, handler: H) -> Self pub fn fallback<H, T>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); let endpoint = Endpoint::MethodRouter(any(handler));
self self.fallback_endpoint(endpoint)
} }
/// Add a fallback [`Service`] to the router. /// Add a fallback [`Service`] to the router.
/// ///
/// See [`Router::fallback`] for more details. /// See [`Router::fallback`] for more details.
pub fn fallback_service<T>(mut self, svc: T) -> Self pub fn fallback_service<T>(self, service: T) -> Self
where where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse, T::Response: IntoResponse,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
self.fallback = Fallback::Service(Route::new(svc)); self.fallback_endpoint(Endpoint::Route(Route::new(service)))
}
fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
self.fallback_router.replace_endpoint("/", endpoint.clone());
self.fallback_router
.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
self.default_fallback = false;
self self
} }
#[doc = include_str!("../docs/routing/with_state.md")] #[doc = include_str!("../docs/routing/with_state.md")]
pub fn with_state<S2>(self, state: S) -> Router<S2, B> { pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let endpoint: Endpoint<S2, B> = match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
Endpoint::NestedRouter(router) => {
Endpoint::Route(router.into_route(state.clone()))
}
};
(id, endpoint)
})
.collect();
let fallback = self.fallback.with_state(state);
Router { Router {
routes, path_router: self.path_router.with_state(state.clone()),
node: self.node, fallback_router: self.fallback_router.with_state(state),
fallback, default_fallback: self.default_fallback,
prev_route_id: self.prev_route_id,
} }
} }
@ -420,85 +290,42 @@ where
mut req: Request<B>, mut req: Request<B>,
state: S, state: S,
) -> RouteFuture<B, Infallible> { ) -> RouteFuture<B, Infallible> {
#[cfg(feature = "original-uri")] // required for opaque routers to still inherit the fallback
{ // TODO(david): remove this feature in 0.7
use crate::extract::OriginalUri; if !self.default_fallback {
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
}
let path = req.uri().path().to_owned();
match self.node.at(&path) {
Ok(match_) => {
match &self.fallback {
Fallback::Default(_) => {}
Fallback::Service(fallback) => {
req.extensions_mut()
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
}
Fallback::BoxedHandler(fallback) => {
req.extensions_mut().insert(SuperFallback(SyncWrapper::new( req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
fallback.clone().into_route(state.clone()), self.fallback_router.clone(),
))); )));
} }
match self.path_router.call_with_state(req, state) {
Ok(future) => {
println!("path_router hit");
future
}
Err((mut req, state)) => {
let super_fallback = req
.extensions_mut()
.remove::<SuperFallback<S, B>>()
.map(|SuperFallback(path_router)| path_router.into_inner());
if let Some(mut super_fallback) = super_fallback {
return super_fallback
.call_with_state(req, state)
.unwrap_or_else(|_| unreachable!());
} }
let id = *match_.value; match self.fallback_router.call_with_state(req, state) {
Ok(future) => future,
#[cfg(feature = "matched-path")] Err((_req, _state)) => {
crate::extract::matched_path::set_matched_path_for_request( unreachable!(
id, "the default fallback added in `Router::new` \
&self.node.route_id_to_path, matches everything"
req.extensions_mut(), )
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let endpont = self
.routes
.get_mut(&id)
.expect("no route for id. This is a bug in axum. Please file an issue");
match endpont {
Endpoint::MethodRouter(method_router) => {
method_router.call_with_state(req, state)
}
Endpoint::Route(route) => route.call(req),
Endpoint::NestedRouter(router) => router.clone().call_with_state(req, state),
} }
} }
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => match &mut self.fallback {
Fallback::Default(fallback) => {
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
{
let mut super_fallback = super_fallback.0.into_inner();
super_fallback.call(req)
} else {
fallback.call(req)
} }
} }
Fallback::Service(fallback) => fallback.call(req),
Fallback::BoxedHandler(handler) => handler.clone().into_route(state).call(req),
},
}
}
fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
.0
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
self.prev_route_id = RouteId(next_id);
self.prev_route_id
} }
} }
@ -563,47 +390,6 @@ where
} }
} }
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
#[derive(Clone, Default)]
struct Node {
inner: matchit::Router<RouteId>,
route_id_to_path: HashMap<RouteId, Arc<str>>,
path_to_route_id: HashMap<Arc<str>, RouteId>,
}
impl Node {
fn insert(
&mut self,
path: impl Into<String>,
val: RouteId,
) -> Result<(), matchit::InsertError> {
let path = path.into();
self.inner.insert(&path, val)?;
let shared_path: Arc<str> = path.into();
self.route_id_to_path.insert(val, shared_path.clone());
self.path_to_route_id.insert(shared_path, val);
Ok(())
}
fn at<'n, 'p>(
&'n self,
path: &'p str,
) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
self.inner.at(path)
}
}
impl fmt::Debug for Node {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Node")
.field("paths", &self.route_id_to_path)
.finish()
}
}
enum Fallback<S, B, E = Infallible> { enum Fallback<S, B, E = Infallible> {
Default(Route<B, E>), Default(Route<B, E>),
Service(Route<B, E>), Service(Route<B, E>),
@ -671,7 +457,6 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
enum Endpoint<S, B> { enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>), MethodRouter(MethodRouter<S, B>),
Route(Route<B>), Route(Route<B>),
NestedRouter(BoxedIntoRoute<S, B, Infallible>),
} }
impl<S, B> Endpoint<S, B> impl<S, B> Endpoint<S, B>
@ -693,9 +478,6 @@ where
Endpoint::MethodRouter(method_router.layer(layer)) Endpoint::MethodRouter(method_router.layer(layer))
} }
Endpoint::Route(route) => Endpoint::Route(route.layer(layer)), Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
Endpoint::NestedRouter(router) => {
Endpoint::NestedRouter(router.map(|route| route.layer(layer)))
}
} }
} }
} }
@ -705,7 +487,6 @@ impl<S, B> Clone for Endpoint<S, B> {
match self { match self {
Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()), Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
Self::Route(inner) => Self::Route(inner.clone()), Self::Route(inner) => Self::Route(inner.clone()),
Self::NestedRouter(router) => Self::NestedRouter(router.clone()),
} }
} }
} }
@ -717,17 +498,11 @@ impl<S, B> fmt::Debug for Endpoint<S, B> {
f.debug_tuple("MethodRouter").field(method_router).finish() f.debug_tuple("MethodRouter").field(method_router).finish()
} }
Self::Route(route) => f.debug_tuple("Route").field(route).finish(), Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
Self::NestedRouter(router) => f.debug_tuple("NestedRouter").field(router).finish(),
} }
} }
} }
enum RouterOrService<S, B, T> { struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B>>);
Router(Router<S, B>),
Service(T),
}
struct SuperFallback<B>(SyncWrapper<Route<B>>);
#[test] #[test]
#[allow(warnings)] #[allow(warnings)]

View file

@ -29,6 +29,7 @@ where
} }
fn call(&mut self, _req: Request<B>) -> Self::Future { fn call(&mut self, _req: Request<B>) -> Self::Future {
println!("NotFound hit");
ready(Ok(StatusCode::NOT_FOUND.into_response())) ready(Ok(StatusCode::NOT_FOUND.into_response()))
} }
} }

View file

@ -0,0 +1,445 @@
use crate::body::{Body, HttpBody};
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
use tower_layer::Layer;
use tower_service::Service;
use super::{
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
RouteId, NEST_TAIL_PARAM,
};
pub(super) struct PathRouter<S = (), B = Body> {
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
prev_route_id: RouteId,
}
impl<S, B> PathRouter<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
pub(super) fn route(
&mut self,
path: &str,
method_router: MethodRouter<S, B>,
) -> Result<(), Cow<'static, str>> {
fn validate_path(path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
return Err("Paths must start with a `/`");
}
Ok(())
}
validate_path(path)?;
let id = self.next_route_id();
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
.path_to_route_id
.get(path)
.and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
{
// if we're adding a new `MethodRouter` to a route that already has one just
// merge them. This makes `.route("/", get(_)).route("/", post(_))` work
let service = Endpoint::MethodRouter(
prev_method_router
.clone()
.merge_for_path(Some(path), method_router),
);
self.routes.insert(route_id, service);
return Ok(());
} else {
Endpoint::MethodRouter(method_router)
};
self.set_node(path, id)?;
self.routes.insert(id, endpoint);
Ok(())
}
pub(super) fn route_service<T>(
&mut self,
path: &str,
service: T,
) -> Result<(), Cow<'static, str>>
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
self.route_endpoint(path, Endpoint::Route(Route::new(service)))
}
pub(super) fn route_endpoint(
&mut self,
path: &str,
endpoint: Endpoint<S, B>,
) -> Result<(), Cow<'static, str>> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
} else if !path.starts_with('/') {
return Err("Paths must start with a `/`".into());
}
let id = self.next_route_id();
self.set_node(path, id)?;
self.routes.insert(id, endpoint);
Ok(())
}
fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
let mut node =
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
if let Err(err) = node.insert(path, id) {
return Err(format!("Invalid route {path:?}: {err}"));
}
self.node = Arc::new(node);
Ok(())
}
pub(super) fn merge(&mut self, other: PathRouter<S, B>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
prev_route_id: _,
} = other;
for (id, route) in routes {
let path = node
.route_id_to_path
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
};
}
Ok(())
}
pub(super) fn nest(
&mut self,
path: &str,
router: PathRouter<S, B>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(path);
let PathRouter {
routes,
node,
prev_route_id: _,
} = router;
for (id, endpoint) in routes {
let inner_path = node
.route_id_to_path
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");
let path = path_for_nested_route(prefix, inner_path);
match endpoint.layer(StripPrefix::layer(prefix)) {
Endpoint::MethodRouter(method_router) => {
self.route(&path, method_router)?;
}
Endpoint::Route(route) => {
self.route_endpoint(&path, Endpoint::Route(route))?;
}
}
}
Ok(())
}
pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let path = validate_nest_path(path);
let prefix = path;
let path = if path.ends_with('/') {
format!("{path}*{NEST_TAIL_PARAM}")
} else {
format!("{path}/*{NEST_TAIL_PARAM}")
};
let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
self.route_endpoint(&path, endpoint.clone())?;
// `/*rest` is not matched by `/` so we need to also register a router at the
// prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
// wouldn't match, which it should
self.route_endpoint(prefix, endpoint.clone())?;
if !prefix.ends_with('/') {
// same goes for `/foo/`, that should also match
self.route_endpoint(&format!("{prefix}/"), endpoint)?;
}
Ok(())
}
pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: HttpBody + 'static,
{
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
PathRouter {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
}
}
#[track_caller]
pub(super) fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
{
if self.routes.is_empty() {
panic!(
"Adding a route_layer before any routes is a no-op. \
Add the routes you want the layer to apply to first."
);
}
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
PathRouter {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
}
}
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, B> {
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let endpoint: Endpoint<S2, B> = match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
};
(id, endpoint)
})
.collect();
PathRouter {
routes,
node: self.node,
prev_route_id: self.prev_route_id,
}
}
pub(super) fn call_with_state(
&mut self,
mut req: Request<B>,
state: S,
) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)> {
#[cfg(feature = "original-uri")]
{
use crate::extract::OriginalUri;
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
}
let path = req.uri().path().to_owned();
match self.node.at(&path) {
Ok(match_) => {
let id = *match_.value;
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let endpont = self
.routes
.get_mut(&id)
.expect("no route for id. This is a bug in axum. Please file an issue");
match endpont {
Endpoint::MethodRouter(method_router) => {
Ok(method_router.call_with_state(req, state))
}
Endpoint::Route(route) => Ok(route.clone().call(req)),
}
}
// explicitly handle all variants in case matchit adds
// new ones we need to handle differently
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => Err((req, state)),
}
}
pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>) {
match self.node.at(path) {
Ok(match_) => {
let id = *match_.value;
self.routes.insert(id, endpoint);
}
Err(_) => self
.route_endpoint(path, endpoint)
.expect("path wasn't matched so endpoint shouldn't exist"),
}
}
fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
.0
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
self.prev_route_id = RouteId(next_id);
self.prev_route_id
}
}
impl<B, S> Default for PathRouter<S, B> {
fn default() -> Self {
Self {
routes: Default::default(),
node: Default::default(),
prev_route_id: RouteId(0),
}
}
}
impl<S, B> fmt::Debug for PathRouter<S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
.field("node", &self.node)
.finish()
}
}
impl<S, B> Clone for PathRouter<S, B> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
node: self.node.clone(),
prev_route_id: self.prev_route_id,
}
}
}
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
#[derive(Clone, Default)]
struct Node {
inner: matchit::Router<RouteId>,
route_id_to_path: HashMap<RouteId, Arc<str>>,
path_to_route_id: HashMap<Arc<str>, RouteId>,
}
impl Node {
fn insert(
&mut self,
path: impl Into<String>,
val: RouteId,
) -> Result<(), matchit::InsertError> {
let path = path.into();
self.inner.insert(&path, val)?;
let shared_path: Arc<str> = path.into();
self.route_id_to_path.insert(val, shared_path.clone());
self.path_to_route_id.insert(shared_path, val);
Ok(())
}
fn at<'n, 'p>(
&'n self,
path: &'p str,
) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
self.inner.at(path)
}
}
impl fmt::Debug for Node {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Node")
.field("paths", &self.route_id_to_path)
.finish()
}
}
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
if path.is_empty() {
// nesting at `""` and `"/"` should mean the same thing
return "/";
}
if path.contains('*') {
panic!("Invalid route: nested routes cannot contain wildcards (*)");
}
path
}
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
debug_assert!(prefix.starts_with('/'));
debug_assert!(path.starts_with('/'));
if prefix.ends_with('/') {
format!("{prefix}{}", path.trim_start_matches('/')).into()
} else if path == "/" {
prefix.into()
} else {
format!("{prefix}{path}").into()
}
}

View file

@ -3,6 +3,8 @@ use std::{
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::Layer;
use tower_layer::layer_fn;
use tower_service::Service; use tower_service::Service;
#[derive(Clone)] #[derive(Clone)]
@ -18,6 +20,14 @@ impl<S> StripPrefix<S> {
prefix: prefix.into(), prefix: prefix.into(),
} }
} }
pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
let prefix = Arc::from(prefix);
layer_fn(move |inner| Self {
inner,
prefix: Arc::clone(&prefix),
})
}
} }
impl<S, B> Service<Request<B>> for StripPrefix<S> impl<S, B> Service<Request<B>> for StripPrefix<S>

View file

@ -93,6 +93,10 @@ async fn doesnt_inherit_fallback_if_overriden() {
let res = client.get("/foo/bar").send().await; let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner"); assert_eq!(res.text().await, "inner");
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
} }
#[crate::test] #[crate::test]
@ -203,3 +207,21 @@ async fn fallback_inherited_into_nested_opaque_service() {
assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer"); assert_eq!(res.text().await, "outer");
} }
#[crate::test]
async fn nest_fallback_on_inner() {
let app = Router::new()
.nest(
"/foo",
Router::new()
.route("/", get(|| async {}))
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
)
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
let client = TestClient::new(app);
let res = client.get("/foo/not-found").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner fallback");
}

View file

@ -4,7 +4,10 @@ use crate::{
extract::{self, DefaultBodyLimit, FromRef, Path, State}, extract::{self, DefaultBodyLimit, FromRef, Path, State},
handler::{Handler, HandlerWithoutStateExt}, handler::{Handler, HandlerWithoutStateExt},
response::IntoResponse, response::IntoResponse,
routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, routing::{
delete, get, get_service, on, on_service, patch, patch_service,
path_router::path_for_nested_route, post, MethodFilter,
},
test_helpers::*, test_helpers::*,
BoxError, Json, Router, BoxError, Json, Router,
}; };
@ -601,7 +604,10 @@ async fn head_with_middleware_applied() {
use tower_http::compression::{predicate::SizeAbove, CompressionLayer}; use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
let app = Router::new() let app = Router::new()
.route("/", get(|| async { "Hello, World!" })) .nest(
"/",
Router::new().route("/", get(|| async { "Hello, World!" })),
)
.layer(CompressionLayer::new().compress_when(SizeAbove::new(0))); .layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
let client = TestClient::new(app); let client = TestClient::new(app);
@ -841,6 +847,21 @@ fn method_router_fallback_with_state() {
.with_state(state); .with_state(state);
} }
#[test]
fn test_path_for_nested_route() {
assert_eq!(path_for_nested_route("/", "/"), "/");
assert_eq!(path_for_nested_route("/a", "/"), "/a");
assert_eq!(path_for_nested_route("/", "/b"), "/b");
assert_eq!(path_for_nested_route("/a/", "/"), "/a/");
assert_eq!(path_for_nested_route("/", "/b/"), "/b/");
assert_eq!(path_for_nested_route("/a", "/b"), "/a/b");
assert_eq!(path_for_nested_route("/a/", "/b"), "/a/b");
assert_eq!(path_for_nested_route("/a", "/b/"), "/a/b/");
assert_eq!(path_for_nested_route("/a/", "/b/"), "/a/b/");
}
#[crate::test] #[crate::test]
async fn state_isnt_cloned_too_much() { async fn state_isnt_cloned_too_much() {
static SETUP_DONE: AtomicBool = AtomicBool::new(false); static SETUP_DONE: AtomicBool = AtomicBool::new(false);

View file

@ -230,6 +230,13 @@ async fn nested_multiple_routes() {
#[test] #[test]
#[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] #[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"]
fn nested_service_at_root_with_other_routes() {
let _: Router = Router::new()
.nest_service("/", Router::new().route("/users", get(|| async {})))
.route("/", get(|| async {}));
}
#[test]
fn nested_at_root_with_other_routes() { fn nested_at_root_with_other_routes() {
let _: Router = Router::new() let _: Router = Router::new()
.nest("/", Router::new().route("/users", get(|| async {}))) .nest("/", Router::new().route("/users", get(|| async {})))
@ -343,42 +350,40 @@ async fn nest_with_and_without_trailing() {
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[crate::test] #[tokio::test]
async fn nesting_with_root_inner_router() { async fn nesting_with_root_inner_router() {
let app = Router::new().nest(
"/foo",
Router::new().route("/", get(|| async { "inner route" })),
);
let client = TestClient::new(app);
// `/foo/` does match the `/foo` prefix and the remaining path is technically
// empty, which is the same as `/` which matches `.route("/", _)`
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
// `/foo/` does match the `/foo` prefix and the remaining path is `/`
// which matches `.route("/", _)`
let res = client.get("/foo/").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[crate::test]
async fn fallback_on_inner() {
let app = Router::new() let app = Router::new()
.nest( .nest_service("/service", Router::new().route("/", get(|| async {})))
"/foo", .nest("/router", Router::new().route("/", get(|| async {})))
Router::new() .nest("/router-slash/", Router::new().route("/", get(|| async {})));
.route("/", get(|| async {}))
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
)
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
let client = TestClient::new(app); let client = TestClient::new(app);
let res = client.get("/foo/not-found").send().await; // `/service/` does match the `/service` prefix and the remaining path is technically
// empty, which is the same as `/` which matches `.route("/", _)`
let res = client.get("/service").send().await;
assert_eq!(res.status(), StatusCode::OK);
// `/service/` does match the `/service` prefix and the remaining path is `/`
// which matches `.route("/", _)`
//
// this is perhaps a little surprising but don't think there is much we can do
let res = client.get("/service/").send().await;
assert_eq!(res.status(), StatusCode::OK);
// at least it does work like you'd expect when using `nest`
let res = client.get("/router").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.get("/router/").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner fallback");
let res = client.get("/router-slash").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
let res = client.get("/router-slash/").send().await;
assert_eq!(res.status(), StatusCode::OK);
} }
macro_rules! nested_route_test { macro_rules! nested_route_test {

View file

@ -19,6 +19,7 @@ pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
let params = params let params = params
.iter() .iter()
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
.filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM))
.map(|(k, v)| { .map(|(k, v)| {
if let Some(decoded) = PercentDecodedStr::new(v) { if let Some(decoded) = PercentDecodedStr::new(v) {
Ok((Arc::from(k), decoded)) Ok((Arc::from(k), decoded))