mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 15:30:16 +01:00
Change Router::nest
to flatten the routes (#1711)
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
parent
14d2b2dc87
commit
2c2cf361dd
11 changed files with 702 additions and 399 deletions
|
@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
# Unreleased
|
||||
|
||||
- **fixed:** Fixed performance regression with `Router::nest` introduced in
|
||||
0.6.0. `nest` now flattens the routes which performs better ([#1711])
|
||||
- **fixed:** Extracting `MatchedPath` in nested handlers now gives the full
|
||||
matched path, including the nested path ([#1711])
|
||||
- **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922])
|
||||
|
||||
[#1711]: https://github.com/tokio-rs/axum/pull/1711
|
||||
[#1922]: https://github.com/tokio-rs/axum/pull/1922
|
||||
|
||||
# 0.6.12 (22. March, 2023)
|
||||
|
|
|
@ -28,25 +28,6 @@ where
|
|||
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn from_router(router: Router<S, B>) -> Self
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
Self(Box::new(MakeErasedRouter {
|
||||
router,
|
||||
into_route: |router, state| Route::new(router.with_state(state)),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn call_with_state(
|
||||
self,
|
||||
request: Request<B>,
|
||||
state: S,
|
||||
) -> RouteFuture<B, Infallible> {
|
||||
self.0.call_with_state(request, state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, E> BoxedIntoRoute<S, B, E> {
|
||||
|
|
|
@ -235,6 +235,26 @@ mod tests {
|
|||
req
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
|
||||
.layer(map_request(extract_matched_path));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn can_extract_nested_matched_path_in_middleware_using_nest() {
|
||||
async fn extract_matched_path<B>(
|
||||
matched_path: Option<MatchedPath>,
|
||||
req: Request<B>,
|
||||
) -> Request<B> {
|
||||
assert_eq!(matched_path.unwrap().as_str(), "/:a/:b");
|
||||
req
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
|
||||
.layer(map_request(extract_matched_path));
|
||||
|
@ -253,7 +273,7 @@ mod tests {
|
|||
}
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
|
||||
.nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
|
||||
.layer(map_request(assert_no_matched_path));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -262,6 +282,23 @@ mod tests {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn can_extract_nested_matched_path_in_middleware_via_extension_using_nest() {
|
||||
async fn assert_matched_path<B>(req: Request<B>) -> Request<B> {
|
||||
assert!(req.extensions().get::<MatchedPath>().is_some());
|
||||
req
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/:a", Router::new().route("/:b", get(|| async move {})))
|
||||
.layer(map_request(assert_matched_path));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
|
||||
async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! Routing between [`Service`]s and handlers.
|
||||
|
||||
use self::{future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix};
|
||||
use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
|
||||
#[cfg(feature = "tokio")]
|
||||
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
|
||||
use crate::{
|
||||
|
@ -11,12 +11,9 @@ use crate::{
|
|||
};
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
use http::Request;
|
||||
use matchit::MatchError;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use sync_wrapper::SyncWrapper;
|
||||
|
@ -29,6 +26,7 @@ pub mod method_routing;
|
|||
mod into_make_service;
|
||||
mod method_filter;
|
||||
mod not_found;
|
||||
pub(crate) mod path_router;
|
||||
mod route;
|
||||
mod strip_prefix;
|
||||
pub(crate) mod url_params;
|
||||
|
@ -44,25 +42,32 @@ pub use self::method_routing::{
|
|||
trace_service, MethodRouter,
|
||||
};
|
||||
|
||||
macro_rules! panic_on_err {
|
||||
($expr:expr) => {
|
||||
match $expr {
|
||||
Ok(x) => x,
|
||||
Err(err) => panic!("{err}"),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub(crate) struct RouteId(u32);
|
||||
|
||||
/// The router type for composing handlers and services.
|
||||
#[must_use]
|
||||
pub struct Router<S = (), B = Body> {
|
||||
routes: HashMap<RouteId, Endpoint<S, B>>,
|
||||
node: Arc<Node>,
|
||||
fallback: Fallback<S, B>,
|
||||
prev_route_id: RouteId,
|
||||
path_router: PathRouter<S, B>,
|
||||
fallback_router: PathRouter<S, B>,
|
||||
default_fallback: bool,
|
||||
}
|
||||
|
||||
impl<S, B> Clone for Router<S, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
routes: self.routes.clone(),
|
||||
node: Arc::clone(&self.node),
|
||||
fallback: self.fallback.clone(),
|
||||
prev_route_id: self.prev_route_id,
|
||||
path_router: self.path_router.clone(),
|
||||
fallback_router: self.fallback_router.clone(),
|
||||
default_fallback: self.default_fallback,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -80,16 +85,16 @@ where
|
|||
impl<S, B> fmt::Debug for Router<S, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Router")
|
||||
.field("routes", &self.routes)
|
||||
.field("node", &self.node)
|
||||
.field("fallback", &self.fallback)
|
||||
.field("prev_route_id", &self.prev_route_id)
|
||||
.field("path_router", &self.path_router)
|
||||
.field("fallback_router", &self.fallback_router)
|
||||
.field("default_fallback", &self.default_fallback)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
|
||||
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
|
||||
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
|
||||
|
||||
impl<S, B> Router<S, B>
|
||||
where
|
||||
|
@ -101,57 +106,25 @@ where
|
|||
/// Unless you add additional routes this will respond with `404 Not Found` to
|
||||
/// all requests.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
routes: Default::default(),
|
||||
node: Default::default(),
|
||||
fallback: Fallback::Default(Route::new(NotFound)),
|
||||
prev_route_id: RouteId(0),
|
||||
}
|
||||
let mut this = Self {
|
||||
path_router: Default::default(),
|
||||
fallback_router: Default::default(),
|
||||
default_fallback: true,
|
||||
};
|
||||
this = this.fallback_service(NotFound);
|
||||
this.default_fallback = true;
|
||||
this
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route.md")]
|
||||
#[track_caller]
|
||||
pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
|
||||
#[track_caller]
|
||||
fn validate_path(path: &str) {
|
||||
if path.is_empty() {
|
||||
panic!("Paths must start with a `/`. Use \"/\" for root routes");
|
||||
} else if !path.starts_with('/') {
|
||||
panic!("Paths must start with a `/`");
|
||||
}
|
||||
}
|
||||
|
||||
validate_path(path);
|
||||
|
||||
let id = self.next_route_id();
|
||||
|
||||
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
|
||||
.node
|
||||
.path_to_route_id
|
||||
.get(path)
|
||||
.and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
|
||||
{
|
||||
// if we're adding a new `MethodRouter` to a route that already has one just
|
||||
// merge them. This makes `.route("/", get(_)).route("/", post(_))` work
|
||||
let service = Endpoint::MethodRouter(
|
||||
prev_method_router
|
||||
.clone()
|
||||
.merge_for_path(Some(path), method_router),
|
||||
);
|
||||
self.routes.insert(route_id, service);
|
||||
return self;
|
||||
} else {
|
||||
Endpoint::MethodRouter(method_router)
|
||||
};
|
||||
|
||||
self.set_node(path, id);
|
||||
self.routes.insert(id, endpoint);
|
||||
|
||||
panic_on_err!(self.path_router.route(path, method_router));
|
||||
self
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route_service.md")]
|
||||
pub fn route_service<T>(self, path: &str, service: T) -> Self
|
||||
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
|
@ -164,104 +137,40 @@ where
|
|||
Use `Router::nest` instead"
|
||||
);
|
||||
}
|
||||
Err(svc) => svc,
|
||||
Err(service) => service,
|
||||
};
|
||||
|
||||
self.route_endpoint(path, Endpoint::Route(Route::new(service)))
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn route_endpoint(mut self, path: &str, endpoint: Endpoint<S, B>) -> Self {
|
||||
if path.is_empty() {
|
||||
panic!("Paths must start with a `/`. Use \"/\" for root routes");
|
||||
} else if !path.starts_with('/') {
|
||||
panic!("Paths must start with a `/`");
|
||||
}
|
||||
|
||||
let id = self.next_route_id();
|
||||
self.set_node(path, id);
|
||||
self.routes.insert(id, endpoint);
|
||||
panic_on_err!(self.path_router.route_service(path, service));
|
||||
self
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn set_node(&mut self, path: &str, id: RouteId) {
|
||||
let mut node =
|
||||
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
|
||||
if let Err(err) = node.insert(path, id) {
|
||||
panic!("Invalid route {path:?}: {err}");
|
||||
}
|
||||
self.node = Arc::new(node);
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/nest.md")]
|
||||
#[track_caller]
|
||||
pub fn nest(self, path: &str, router: Router<S, B>) -> Self {
|
||||
self.nest_endpoint(path, RouterOrService::<_, _, NotFound>::Router(router))
|
||||
pub fn nest(mut self, path: &str, router: Router<S, B>) -> Self {
|
||||
let Router {
|
||||
path_router,
|
||||
fallback_router,
|
||||
default_fallback,
|
||||
} = router;
|
||||
|
||||
panic_on_err!(self.path_router.nest(path, path_router));
|
||||
|
||||
if !default_fallback {
|
||||
panic_on_err!(self.fallback_router.nest(path, fallback_router));
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
|
||||
#[track_caller]
|
||||
pub fn nest_service<T>(self, path: &str, svc: T) -> Self
|
||||
pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.nest_endpoint(path, RouterOrService::Service(svc))
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn nest_endpoint<T>(
|
||||
mut self,
|
||||
mut path: &str,
|
||||
router_or_service: RouterOrService<S, B, T>,
|
||||
) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
if path.is_empty() {
|
||||
// nesting at `""` and `"/"` should mean the same thing
|
||||
path = "/";
|
||||
}
|
||||
|
||||
if path.contains('*') {
|
||||
panic!("Invalid route: nested routes cannot contain wildcards (*)");
|
||||
}
|
||||
|
||||
let prefix = path;
|
||||
|
||||
let path = if path.ends_with('/') {
|
||||
format!("{path}*{NEST_TAIL_PARAM}")
|
||||
} else {
|
||||
format!("{path}/*{NEST_TAIL_PARAM}")
|
||||
};
|
||||
|
||||
let endpoint = match router_or_service {
|
||||
RouterOrService::Router(router) => {
|
||||
let prefix = prefix.to_owned();
|
||||
let boxed = BoxedIntoRoute::from_router(router)
|
||||
.map(move |route| Route::new(StripPrefix::new(route, &prefix)));
|
||||
Endpoint::NestedRouter(boxed)
|
||||
}
|
||||
RouterOrService::Service(svc) => {
|
||||
Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)))
|
||||
}
|
||||
};
|
||||
|
||||
self = self.route_endpoint(&path, endpoint.clone());
|
||||
|
||||
// `/*rest` is not matched by `/` so we need to also register a router at the
|
||||
// prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
|
||||
// wouldn't match, which it should
|
||||
self = self.route_endpoint(prefix, endpoint.clone());
|
||||
if !prefix.ends_with('/') {
|
||||
// same goes for `/foo/`, that should also match
|
||||
self = self.route_endpoint(&format!("{prefix}/"), endpoint);
|
||||
}
|
||||
|
||||
panic_on_err!(self.path_router.nest_service(path, service));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -272,30 +181,32 @@ where
|
|||
R: Into<Router<S, B>>,
|
||||
{
|
||||
let Router {
|
||||
routes,
|
||||
node,
|
||||
fallback,
|
||||
prev_route_id: _,
|
||||
path_router,
|
||||
fallback_router: other_fallback,
|
||||
default_fallback,
|
||||
} = other.into();
|
||||
|
||||
for (id, route) in routes {
|
||||
let path = node
|
||||
.route_id_to_path
|
||||
.get(&id)
|
||||
.expect("no path for route id. This is a bug in axum. Please file an issue");
|
||||
self = match route {
|
||||
Endpoint::MethodRouter(method_router) => self.route(path, method_router),
|
||||
Endpoint::Route(route) => self.route_service(path, route),
|
||||
Endpoint::NestedRouter(router) => {
|
||||
self.route_endpoint(path, Endpoint::NestedRouter(router))
|
||||
}
|
||||
};
|
||||
}
|
||||
panic_on_err!(self.path_router.merge(path_router));
|
||||
|
||||
self.fallback = self
|
||||
.fallback
|
||||
.merge(fallback)
|
||||
.expect("Cannot merge two `Router`s that both have a fallback");
|
||||
match (self.default_fallback, default_fallback) {
|
||||
// both have the default fallback
|
||||
// use the one from other
|
||||
(true, true) => {
|
||||
self.fallback_router = other_fallback;
|
||||
}
|
||||
// self has default fallback, other has a custom fallback
|
||||
(true, false) => {
|
||||
self.fallback_router = other_fallback;
|
||||
self.default_fallback = false;
|
||||
}
|
||||
// self has a custom fallback, other has a default
|
||||
// nothing to do
|
||||
(false, true) => {}
|
||||
// both have a custom fallback, not allowed
|
||||
(false, false) => {
|
||||
panic!("Cannot merge two `Router`s that both have a fallback")
|
||||
}
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -310,22 +221,10 @@ where
|
|||
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
|
||||
NewReqBody: HttpBody + 'static,
|
||||
{
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let route = endpoint.layer(layer.clone());
|
||||
(id, route)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fallback = self.fallback.map(|route| route.layer(layer));
|
||||
|
||||
Router {
|
||||
routes,
|
||||
node: self.node,
|
||||
fallback,
|
||||
prev_route_id: self.prev_route_id,
|
||||
path_router: self.path_router.layer(layer.clone()),
|
||||
fallback_router: self.fallback_router.layer(layer),
|
||||
default_fallback: self.default_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -339,79 +238,50 @@ where
|
|||
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
|
||||
<L::Service as Service<Request<B>>>::Future: Send + 'static,
|
||||
{
|
||||
if self.routes.is_empty() {
|
||||
panic!(
|
||||
"Adding a route_layer before any routes is a no-op. \
|
||||
Add the routes you want the layer to apply to first."
|
||||
);
|
||||
}
|
||||
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let route = endpoint.layer(layer.clone());
|
||||
(id, route)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Router {
|
||||
routes,
|
||||
node: self.node,
|
||||
fallback: self.fallback,
|
||||
prev_route_id: self.prev_route_id,
|
||||
path_router: self.path_router.route_layer(layer),
|
||||
fallback_router: self.fallback_router,
|
||||
default_fallback: self.default_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||
pub fn fallback<H, T>(mut self, handler: H) -> Self
|
||||
pub fn fallback<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: Handler<T, S, B>,
|
||||
T: 'static,
|
||||
{
|
||||
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
|
||||
self
|
||||
let endpoint = Endpoint::MethodRouter(any(handler));
|
||||
self.fallback_endpoint(endpoint)
|
||||
}
|
||||
|
||||
/// Add a fallback [`Service`] to the router.
|
||||
///
|
||||
/// See [`Router::fallback`] for more details.
|
||||
pub fn fallback_service<T>(mut self, svc: T) -> Self
|
||||
pub fn fallback_service<T>(self, service: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.fallback = Fallback::Service(Route::new(svc));
|
||||
self.fallback_endpoint(Endpoint::Route(Route::new(service)))
|
||||
}
|
||||
|
||||
fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
|
||||
self.fallback_router.replace_endpoint("/", endpoint.clone());
|
||||
self.fallback_router
|
||||
.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
|
||||
self.default_fallback = false;
|
||||
self
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/with_state.md")]
|
||||
pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let endpoint: Endpoint<S2, B> = match endpoint {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Endpoint::MethodRouter(method_router.with_state(state.clone()))
|
||||
}
|
||||
Endpoint::Route(route) => Endpoint::Route(route),
|
||||
Endpoint::NestedRouter(router) => {
|
||||
Endpoint::Route(router.into_route(state.clone()))
|
||||
}
|
||||
};
|
||||
(id, endpoint)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let fallback = self.fallback.with_state(state);
|
||||
|
||||
Router {
|
||||
routes,
|
||||
node: self.node,
|
||||
fallback,
|
||||
prev_route_id: self.prev_route_id,
|
||||
path_router: self.path_router.with_state(state.clone()),
|
||||
fallback_router: self.fallback_router.with_state(state),
|
||||
default_fallback: self.default_fallback,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -420,86 +290,43 @@ where
|
|||
mut req: Request<B>,
|
||||
state: S,
|
||||
) -> RouteFuture<B, Infallible> {
|
||||
#[cfg(feature = "original-uri")]
|
||||
{
|
||||
use crate::extract::OriginalUri;
|
||||
|
||||
if req.extensions().get::<OriginalUri>().is_none() {
|
||||
let original_uri = OriginalUri(req.uri().clone());
|
||||
req.extensions_mut().insert(original_uri);
|
||||
}
|
||||
// required for opaque routers to still inherit the fallback
|
||||
// TODO(david): remove this feature in 0.7
|
||||
if !self.default_fallback {
|
||||
req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
|
||||
self.fallback_router.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
let path = req.uri().path().to_owned();
|
||||
match self.path_router.call_with_state(req, state) {
|
||||
Ok(future) => {
|
||||
println!("path_router hit");
|
||||
future
|
||||
}
|
||||
Err((mut req, state)) => {
|
||||
let super_fallback = req
|
||||
.extensions_mut()
|
||||
.remove::<SuperFallback<S, B>>()
|
||||
.map(|SuperFallback(path_router)| path_router.into_inner());
|
||||
|
||||
match self.node.at(&path) {
|
||||
Ok(match_) => {
|
||||
match &self.fallback {
|
||||
Fallback::Default(_) => {}
|
||||
Fallback::Service(fallback) => {
|
||||
req.extensions_mut()
|
||||
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
|
||||
}
|
||||
Fallback::BoxedHandler(fallback) => {
|
||||
req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
|
||||
fallback.clone().into_route(state.clone()),
|
||||
)));
|
||||
}
|
||||
if let Some(mut super_fallback) = super_fallback {
|
||||
return super_fallback
|
||||
.call_with_state(req, state)
|
||||
.unwrap_or_else(|_| unreachable!());
|
||||
}
|
||||
|
||||
let id = *match_.value;
|
||||
|
||||
#[cfg(feature = "matched-path")]
|
||||
crate::extract::matched_path::set_matched_path_for_request(
|
||||
id,
|
||||
&self.node.route_id_to_path,
|
||||
req.extensions_mut(),
|
||||
);
|
||||
|
||||
url_params::insert_url_params(req.extensions_mut(), match_.params);
|
||||
|
||||
let endpont = self
|
||||
.routes
|
||||
.get_mut(&id)
|
||||
.expect("no route for id. This is a bug in axum. Please file an issue");
|
||||
|
||||
match endpont {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
method_router.call_with_state(req, state)
|
||||
match self.fallback_router.call_with_state(req, state) {
|
||||
Ok(future) => future,
|
||||
Err((_req, _state)) => {
|
||||
unreachable!(
|
||||
"the default fallback added in `Router::new` \
|
||||
matches everything"
|
||||
)
|
||||
}
|
||||
Endpoint::Route(route) => route.call(req),
|
||||
Endpoint::NestedRouter(router) => router.clone().call_with_state(req, state),
|
||||
}
|
||||
}
|
||||
Err(
|
||||
MatchError::NotFound
|
||||
| MatchError::ExtraTrailingSlash
|
||||
| MatchError::MissingTrailingSlash,
|
||||
) => match &mut self.fallback {
|
||||
Fallback::Default(fallback) => {
|
||||
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
|
||||
{
|
||||
let mut super_fallback = super_fallback.0.into_inner();
|
||||
super_fallback.call(req)
|
||||
} else {
|
||||
fallback.call(req)
|
||||
}
|
||||
}
|
||||
Fallback::Service(fallback) => fallback.call(req),
|
||||
Fallback::BoxedHandler(handler) => handler.clone().into_route(state).call(req),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn next_route_id(&mut self) -> RouteId {
|
||||
let next_id = self
|
||||
.prev_route_id
|
||||
.0
|
||||
.checked_add(1)
|
||||
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
|
||||
self.prev_route_id = RouteId(next_id);
|
||||
self.prev_route_id
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Router<(), B>
|
||||
|
@ -563,47 +390,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
|
||||
#[derive(Clone, Default)]
|
||||
struct Node {
|
||||
inner: matchit::Router<RouteId>,
|
||||
route_id_to_path: HashMap<RouteId, Arc<str>>,
|
||||
path_to_route_id: HashMap<Arc<str>, RouteId>,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
fn insert(
|
||||
&mut self,
|
||||
path: impl Into<String>,
|
||||
val: RouteId,
|
||||
) -> Result<(), matchit::InsertError> {
|
||||
let path = path.into();
|
||||
|
||||
self.inner.insert(&path, val)?;
|
||||
|
||||
let shared_path: Arc<str> = path.into();
|
||||
self.route_id_to_path.insert(val, shared_path.clone());
|
||||
self.path_to_route_id.insert(shared_path, val);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn at<'n, 'p>(
|
||||
&'n self,
|
||||
path: &'p str,
|
||||
) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
|
||||
self.inner.at(path)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Node {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Node")
|
||||
.field("paths", &self.route_id_to_path)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
enum Fallback<S, B, E = Infallible> {
|
||||
Default(Route<B, E>),
|
||||
Service(Route<B, E>),
|
||||
|
@ -671,7 +457,6 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
|
|||
enum Endpoint<S, B> {
|
||||
MethodRouter(MethodRouter<S, B>),
|
||||
Route(Route<B>),
|
||||
NestedRouter(BoxedIntoRoute<S, B, Infallible>),
|
||||
}
|
||||
|
||||
impl<S, B> Endpoint<S, B>
|
||||
|
@ -693,9 +478,6 @@ where
|
|||
Endpoint::MethodRouter(method_router.layer(layer))
|
||||
}
|
||||
Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
|
||||
Endpoint::NestedRouter(router) => {
|
||||
Endpoint::NestedRouter(router.map(|route| route.layer(layer)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -705,7 +487,6 @@ impl<S, B> Clone for Endpoint<S, B> {
|
|||
match self {
|
||||
Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
|
||||
Self::Route(inner) => Self::Route(inner.clone()),
|
||||
Self::NestedRouter(router) => Self::NestedRouter(router.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -717,17 +498,11 @@ impl<S, B> fmt::Debug for Endpoint<S, B> {
|
|||
f.debug_tuple("MethodRouter").field(method_router).finish()
|
||||
}
|
||||
Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
|
||||
Self::NestedRouter(router) => f.debug_tuple("NestedRouter").field(router).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum RouterOrService<S, B, T> {
|
||||
Router(Router<S, B>),
|
||||
Service(T),
|
||||
}
|
||||
|
||||
struct SuperFallback<B>(SyncWrapper<Route<B>>);
|
||||
struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B>>);
|
||||
|
||||
#[test]
|
||||
#[allow(warnings)]
|
||||
|
|
|
@ -29,6 +29,7 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||
println!("NotFound hit");
|
||||
ready(Ok(StatusCode::NOT_FOUND.into_response()))
|
||||
}
|
||||
}
|
||||
|
|
445
axum/src/routing/path_router.rs
Normal file
445
axum/src/routing/path_router.rs
Normal file
|
@ -0,0 +1,445 @@
|
|||
use crate::body::{Body, HttpBody};
|
||||
use axum_core::response::IntoResponse;
|
||||
use http::Request;
|
||||
use matchit::MatchError;
|
||||
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
use super::{
|
||||
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
|
||||
RouteId, NEST_TAIL_PARAM,
|
||||
};
|
||||
|
||||
pub(super) struct PathRouter<S = (), B = Body> {
|
||||
routes: HashMap<RouteId, Endpoint<S, B>>,
|
||||
node: Arc<Node>,
|
||||
prev_route_id: RouteId,
|
||||
}
|
||||
|
||||
impl<S, B> PathRouter<S, B>
|
||||
where
|
||||
B: HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
pub(super) fn route(
|
||||
&mut self,
|
||||
path: &str,
|
||||
method_router: MethodRouter<S, B>,
|
||||
) -> Result<(), Cow<'static, str>> {
|
||||
fn validate_path(path: &str) -> Result<(), &'static str> {
|
||||
if path.is_empty() {
|
||||
return Err("Paths must start with a `/`. Use \"/\" for root routes");
|
||||
} else if !path.starts_with('/') {
|
||||
return Err("Paths must start with a `/`");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
validate_path(path)?;
|
||||
|
||||
let id = self.next_route_id();
|
||||
|
||||
let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
|
||||
.node
|
||||
.path_to_route_id
|
||||
.get(path)
|
||||
.and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
|
||||
{
|
||||
// if we're adding a new `MethodRouter` to a route that already has one just
|
||||
// merge them. This makes `.route("/", get(_)).route("/", post(_))` work
|
||||
let service = Endpoint::MethodRouter(
|
||||
prev_method_router
|
||||
.clone()
|
||||
.merge_for_path(Some(path), method_router),
|
||||
);
|
||||
self.routes.insert(route_id, service);
|
||||
return Ok(());
|
||||
} else {
|
||||
Endpoint::MethodRouter(method_router)
|
||||
};
|
||||
|
||||
self.set_node(path, id)?;
|
||||
self.routes.insert(id, endpoint);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn route_service<T>(
|
||||
&mut self,
|
||||
path: &str,
|
||||
service: T,
|
||||
) -> Result<(), Cow<'static, str>>
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.route_endpoint(path, Endpoint::Route(Route::new(service)))
|
||||
}
|
||||
|
||||
pub(super) fn route_endpoint(
|
||||
&mut self,
|
||||
path: &str,
|
||||
endpoint: Endpoint<S, B>,
|
||||
) -> Result<(), Cow<'static, str>> {
|
||||
if path.is_empty() {
|
||||
return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
|
||||
} else if !path.starts_with('/') {
|
||||
return Err("Paths must start with a `/`".into());
|
||||
}
|
||||
|
||||
let id = self.next_route_id();
|
||||
self.set_node(path, id)?;
|
||||
self.routes.insert(id, endpoint);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
|
||||
let mut node =
|
||||
Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
|
||||
if let Err(err) = node.insert(path, id) {
|
||||
return Err(format!("Invalid route {path:?}: {err}"));
|
||||
}
|
||||
self.node = Arc::new(node);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn merge(&mut self, other: PathRouter<S, B>) -> Result<(), Cow<'static, str>> {
|
||||
let PathRouter {
|
||||
routes,
|
||||
node,
|
||||
prev_route_id: _,
|
||||
} = other;
|
||||
|
||||
for (id, route) in routes {
|
||||
let path = node
|
||||
.route_id_to_path
|
||||
.get(&id)
|
||||
.expect("no path for route id. This is a bug in axum. Please file an issue");
|
||||
match route {
|
||||
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
|
||||
Endpoint::Route(route) => self.route_service(path, route)?,
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn nest(
|
||||
&mut self,
|
||||
path: &str,
|
||||
router: PathRouter<S, B>,
|
||||
) -> Result<(), Cow<'static, str>> {
|
||||
let prefix = validate_nest_path(path);
|
||||
|
||||
let PathRouter {
|
||||
routes,
|
||||
node,
|
||||
prev_route_id: _,
|
||||
} = router;
|
||||
|
||||
for (id, endpoint) in routes {
|
||||
let inner_path = node
|
||||
.route_id_to_path
|
||||
.get(&id)
|
||||
.expect("no path for route id. This is a bug in axum. Please file an issue");
|
||||
|
||||
let path = path_for_nested_route(prefix, inner_path);
|
||||
|
||||
match endpoint.layer(StripPrefix::layer(prefix)) {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
self.route(&path, method_router)?;
|
||||
}
|
||||
Endpoint::Route(route) => {
|
||||
self.route_endpoint(&path, Endpoint::Route(route))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
|
||||
where
|
||||
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
let path = validate_nest_path(path);
|
||||
let prefix = path;
|
||||
|
||||
let path = if path.ends_with('/') {
|
||||
format!("{path}*{NEST_TAIL_PARAM}")
|
||||
} else {
|
||||
format!("{path}/*{NEST_TAIL_PARAM}")
|
||||
};
|
||||
|
||||
let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
|
||||
|
||||
self.route_endpoint(&path, endpoint.clone())?;
|
||||
|
||||
// `/*rest` is not matched by `/` so we need to also register a router at the
|
||||
// prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
|
||||
// wouldn't match, which it should
|
||||
self.route_endpoint(prefix, endpoint.clone())?;
|
||||
if !prefix.ends_with('/') {
|
||||
// same goes for `/foo/`, that should also match
|
||||
self.route_endpoint(&format!("{prefix}/"), endpoint)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody>
|
||||
where
|
||||
L: Layer<Route<B>> + Clone + Send + 'static,
|
||||
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
|
||||
NewReqBody: HttpBody + 'static,
|
||||
{
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let route = endpoint.layer(layer.clone());
|
||||
(id, route)
|
||||
})
|
||||
.collect();
|
||||
|
||||
PathRouter {
|
||||
routes,
|
||||
node: self.node,
|
||||
prev_route_id: self.prev_route_id,
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub(super) fn route_layer<L>(self, layer: L) -> Self
|
||||
where
|
||||
L: Layer<Route<B>> + Clone + Send + 'static,
|
||||
L::Service: Service<Request<B>> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
|
||||
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
|
||||
<L::Service as Service<Request<B>>>::Future: Send + 'static,
|
||||
{
|
||||
if self.routes.is_empty() {
|
||||
panic!(
|
||||
"Adding a route_layer before any routes is a no-op. \
|
||||
Add the routes you want the layer to apply to first."
|
||||
);
|
||||
}
|
||||
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let route = endpoint.layer(layer.clone());
|
||||
(id, route)
|
||||
})
|
||||
.collect();
|
||||
|
||||
PathRouter {
|
||||
routes,
|
||||
node: self.node,
|
||||
prev_route_id: self.prev_route_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, B> {
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, endpoint)| {
|
||||
let endpoint: Endpoint<S2, B> = match endpoint {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Endpoint::MethodRouter(method_router.with_state(state.clone()))
|
||||
}
|
||||
Endpoint::Route(route) => Endpoint::Route(route),
|
||||
};
|
||||
(id, endpoint)
|
||||
})
|
||||
.collect();
|
||||
|
||||
PathRouter {
|
||||
routes,
|
||||
node: self.node,
|
||||
prev_route_id: self.prev_route_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn call_with_state(
|
||||
&mut self,
|
||||
mut req: Request<B>,
|
||||
state: S,
|
||||
) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)> {
|
||||
#[cfg(feature = "original-uri")]
|
||||
{
|
||||
use crate::extract::OriginalUri;
|
||||
|
||||
if req.extensions().get::<OriginalUri>().is_none() {
|
||||
let original_uri = OriginalUri(req.uri().clone());
|
||||
req.extensions_mut().insert(original_uri);
|
||||
}
|
||||
}
|
||||
|
||||
let path = req.uri().path().to_owned();
|
||||
|
||||
match self.node.at(&path) {
|
||||
Ok(match_) => {
|
||||
let id = *match_.value;
|
||||
|
||||
#[cfg(feature = "matched-path")]
|
||||
crate::extract::matched_path::set_matched_path_for_request(
|
||||
id,
|
||||
&self.node.route_id_to_path,
|
||||
req.extensions_mut(),
|
||||
);
|
||||
|
||||
url_params::insert_url_params(req.extensions_mut(), match_.params);
|
||||
|
||||
let endpont = self
|
||||
.routes
|
||||
.get_mut(&id)
|
||||
.expect("no route for id. This is a bug in axum. Please file an issue");
|
||||
|
||||
match endpont {
|
||||
Endpoint::MethodRouter(method_router) => {
|
||||
Ok(method_router.call_with_state(req, state))
|
||||
}
|
||||
Endpoint::Route(route) => Ok(route.clone().call(req)),
|
||||
}
|
||||
}
|
||||
// explicitly handle all variants in case matchit adds
|
||||
// new ones we need to handle differently
|
||||
Err(
|
||||
MatchError::NotFound
|
||||
| MatchError::ExtraTrailingSlash
|
||||
| MatchError::MissingTrailingSlash,
|
||||
) => Err((req, state)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>) {
|
||||
match self.node.at(path) {
|
||||
Ok(match_) => {
|
||||
let id = *match_.value;
|
||||
self.routes.insert(id, endpoint);
|
||||
}
|
||||
Err(_) => self
|
||||
.route_endpoint(path, endpoint)
|
||||
.expect("path wasn't matched so endpoint shouldn't exist"),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_route_id(&mut self) -> RouteId {
|
||||
let next_id = self
|
||||
.prev_route_id
|
||||
.0
|
||||
.checked_add(1)
|
||||
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
|
||||
self.prev_route_id = RouteId(next_id);
|
||||
self.prev_route_id
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, S> Default for PathRouter<S, B> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
routes: Default::default(),
|
||||
node: Default::default(),
|
||||
prev_route_id: RouteId(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B> fmt::Debug for PathRouter<S, B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("PathRouter")
|
||||
.field("routes", &self.routes)
|
||||
.field("node", &self.node)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B> Clone for PathRouter<S, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
routes: self.routes.clone(),
|
||||
node: self.node.clone(),
|
||||
prev_route_id: self.prev_route_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
|
||||
#[derive(Clone, Default)]
|
||||
struct Node {
|
||||
inner: matchit::Router<RouteId>,
|
||||
route_id_to_path: HashMap<RouteId, Arc<str>>,
|
||||
path_to_route_id: HashMap<Arc<str>, RouteId>,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
fn insert(
|
||||
&mut self,
|
||||
path: impl Into<String>,
|
||||
val: RouteId,
|
||||
) -> Result<(), matchit::InsertError> {
|
||||
let path = path.into();
|
||||
|
||||
self.inner.insert(&path, val)?;
|
||||
|
||||
let shared_path: Arc<str> = path.into();
|
||||
self.route_id_to_path.insert(val, shared_path.clone());
|
||||
self.path_to_route_id.insert(shared_path, val);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn at<'n, 'p>(
|
||||
&'n self,
|
||||
path: &'p str,
|
||||
) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
|
||||
self.inner.at(path)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Node {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Node")
|
||||
.field("paths", &self.route_id_to_path)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn validate_nest_path(path: &str) -> &str {
|
||||
if path.is_empty() {
|
||||
// nesting at `""` and `"/"` should mean the same thing
|
||||
return "/";
|
||||
}
|
||||
|
||||
if path.contains('*') {
|
||||
panic!("Invalid route: nested routes cannot contain wildcards (*)");
|
||||
}
|
||||
|
||||
path
|
||||
}
|
||||
|
||||
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
|
||||
debug_assert!(prefix.starts_with('/'));
|
||||
debug_assert!(path.starts_with('/'));
|
||||
|
||||
if prefix.ends_with('/') {
|
||||
format!("{prefix}{}", path.trim_start_matches('/')).into()
|
||||
} else if path == "/" {
|
||||
prefix.into()
|
||||
} else {
|
||||
format!("{prefix}{path}").into()
|
||||
}
|
||||
}
|
|
@ -3,6 +3,8 @@ use std::{
|
|||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::Layer;
|
||||
use tower_layer::layer_fn;
|
||||
use tower_service::Service;
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -18,6 +20,14 @@ impl<S> StripPrefix<S> {
|
|||
prefix: prefix.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
|
||||
let prefix = Arc::from(prefix);
|
||||
layer_fn(move |inner| Self {
|
||||
inner,
|
||||
prefix: Arc::clone(&prefix),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B> Service<Request<B>> for StripPrefix<S>
|
||||
|
|
|
@ -93,6 +93,10 @@ async fn doesnt_inherit_fallback_if_overriden() {
|
|||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "inner");
|
||||
|
||||
let res = client.get("/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
|
@ -203,3 +207,21 @@ async fn fallback_inherited_into_nested_opaque_service() {
|
|||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn nest_fallback_on_inner() {
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
|
||||
)
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/not-found").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "inner fallback");
|
||||
}
|
||||
|
|
|
@ -4,7 +4,10 @@ use crate::{
|
|||
extract::{self, DefaultBodyLimit, FromRef, Path, State},
|
||||
handler::{Handler, HandlerWithoutStateExt},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter},
|
||||
routing::{
|
||||
delete, get, get_service, on, on_service, patch, patch_service,
|
||||
path_router::path_for_nested_route, post, MethodFilter,
|
||||
},
|
||||
test_helpers::*,
|
||||
BoxError, Json, Router,
|
||||
};
|
||||
|
@ -601,7 +604,10 @@ async fn head_with_middleware_applied() {
|
|||
use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
.nest(
|
||||
"/",
|
||||
Router::new().route("/", get(|| async { "Hello, World!" })),
|
||||
)
|
||||
.layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
@ -841,6 +847,21 @@ fn method_router_fallback_with_state() {
|
|||
.with_state(state);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_path_for_nested_route() {
|
||||
assert_eq!(path_for_nested_route("/", "/"), "/");
|
||||
|
||||
assert_eq!(path_for_nested_route("/a", "/"), "/a");
|
||||
assert_eq!(path_for_nested_route("/", "/b"), "/b");
|
||||
assert_eq!(path_for_nested_route("/a/", "/"), "/a/");
|
||||
assert_eq!(path_for_nested_route("/", "/b/"), "/b/");
|
||||
|
||||
assert_eq!(path_for_nested_route("/a", "/b"), "/a/b");
|
||||
assert_eq!(path_for_nested_route("/a/", "/b"), "/a/b");
|
||||
assert_eq!(path_for_nested_route("/a", "/b/"), "/a/b/");
|
||||
assert_eq!(path_for_nested_route("/a/", "/b/"), "/a/b/");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn state_isnt_cloned_too_much() {
|
||||
static SETUP_DONE: AtomicBool = AtomicBool::new(false);
|
||||
|
|
|
@ -230,6 +230,13 @@ async fn nested_multiple_routes() {
|
|||
|
||||
#[test]
|
||||
#[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"]
|
||||
fn nested_service_at_root_with_other_routes() {
|
||||
let _: Router = Router::new()
|
||||
.nest_service("/", Router::new().route("/users", get(|| async {})))
|
||||
.route("/", get(|| async {}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_at_root_with_other_routes() {
|
||||
let _: Router = Router::new()
|
||||
.nest("/", Router::new().route("/users", get(|| async {})))
|
||||
|
@ -343,42 +350,40 @@ async fn nest_with_and_without_trailing() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
#[tokio::test]
|
||||
async fn nesting_with_root_inner_router() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new().route("/", get(|| async { "inner route" })),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
// `/foo/` does match the `/foo` prefix and the remaining path is technically
|
||||
// empty, which is the same as `/` which matches `.route("/", _)`
|
||||
let res = client.get("/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
// `/foo/` does match the `/foo` prefix and the remaining path is `/`
|
||||
// which matches `.route("/", _)`
|
||||
let res = client.get("/foo/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn fallback_on_inner() {
|
||||
let app = Router::new()
|
||||
.nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }),
|
||||
)
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
|
||||
.nest_service("/service", Router::new().route("/", get(|| async {})))
|
||||
.nest("/router", Router::new().route("/", get(|| async {})))
|
||||
.nest("/router-slash/", Router::new().route("/", get(|| async {})));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/not-found").send().await;
|
||||
// `/service/` does match the `/service` prefix and the remaining path is technically
|
||||
// empty, which is the same as `/` which matches `.route("/", _)`
|
||||
let res = client.get("/service").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
// `/service/` does match the `/service` prefix and the remaining path is `/`
|
||||
// which matches `.route("/", _)`
|
||||
//
|
||||
// this is perhaps a little surprising but don't think there is much we can do
|
||||
let res = client.get("/service/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
// at least it does work like you'd expect when using `nest`
|
||||
|
||||
let res = client.get("/router").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.get("/router/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "inner fallback");
|
||||
|
||||
let res = client.get("/router-slash").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
let res = client.get("/router-slash/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
macro_rules! nested_route_test {
|
||||
|
|
|
@ -19,6 +19,7 @@ pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) {
|
|||
let params = params
|
||||
.iter()
|
||||
.filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM))
|
||||
.filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM))
|
||||
.map(|(k, v)| {
|
||||
if let Some(decoded) = PercentDecodedStr::new(v) {
|
||||
Ok((Arc::from(k), decoded))
|
||||
|
|
Loading…
Reference in a new issue