diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 8bb09d6c..3167a3a9 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -6,7 +6,7 @@ use crate::{ body::{box_body, Body, BoxBody}, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, - OriginalUri, + MatchedPath, OriginalUri, }, util::{ByteStr, PercentDecodedByteStr}, BoxError, @@ -33,9 +33,9 @@ pub mod service_method_router; mod into_make_service; mod method_filter; mod method_not_allowed; -mod nested; mod not_found; mod route; +mod strip_prefix; pub(crate) use self::method_not_allowed::MethodNotAllowed; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; @@ -92,7 +92,8 @@ impl fmt::Debug for Router { } } -const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture"; +pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest"; +const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest"; impl Router where @@ -342,20 +343,43 @@ where panic!("Invalid route: nested routes cannot contain wildcards (*)"); } - let id = RouteId::next(); + let prefix = path; - let path = if path == "/" { - format!("/*{}", NEST_TAIL_PARAM) - } else { - format!("{}/*{}", path, NEST_TAIL_PARAM) - }; + match try_downcast::, _>(svc) { + // if the user is nesting a `Router` we can implement nesting + // by simplying copying all the routes and adding the prefix in + // front + Ok(router) => { + let Router { + mut routes, + node, + fallback: _, + } = router; - if let Err(err) = self.node.insert(path, id) { - panic!("Invalid route: {}", err); + for (id, nested_path) in node.paths { + let route = routes.remove(&id).unwrap(); + let full_path = if &*nested_path == "/" { + path.to_string() + } else { + format!("{}{}", path, nested_path) + }; + self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix)); + } + + debug_assert!(routes.is_empty()); + } + // otherwise we add a wildcard route to the service + Err(svc) => { + let path = if path == "/" { + format!("/*{}", NEST_TAIL_PARAM) + } else { + format!("{}/*{}", path, NEST_TAIL_PARAM) + }; + + self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix)); + } } - self.routes.insert(id, Route::new(nested::Nested { svc })); - self } @@ -700,31 +724,8 @@ where /// ## When used with `Router::nest` /// /// If a router with a fallback is nested inside another router the fallback - /// will only apply to requests that matches the prefix: - /// - /// ```rust - /// use axum::{ - /// Router, - /// routing::get, - /// handler::Handler, - /// response::IntoResponse, - /// http::{StatusCode, Uri}, - /// }; - /// - /// let api = Router::new() - /// .route("/", get(|| async { /* ... */ })) - /// .fallback(api_fallback.into_service()); - /// - /// let app = Router::new().nest("/api", api); - /// - /// async fn api_fallback() -> impl IntoResponse { /* ... */ } - /// - /// // `api_fallback` will be called for `/api/some-unknown-path` but not for - /// // `/some-unknown-path` as the path doesn't start with `/api` - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` + /// of the nested router will be discarded and not used. This is such that + /// the outer router's fallback takes precedence. pub fn fallback(mut self, svc: T) -> Self where T: Service, Response = Response, Error = Infallible> @@ -743,8 +744,22 @@ where req.extensions_mut().insert(id); if let Some(matched_path) = self.node.paths.get(&id) { - req.extensions_mut() - .insert(crate::extract::MatchedPath(matched_path.clone())); + let matched_path = if let Some(previous) = req.extensions_mut().get::() { + // a previous `MatchedPath` might exist if we're inside a nested Router + let previous = if let Some(previous) = + previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) + { + previous + } else { + previous.as_str() + }; + + let matched_path = format!("{}{}", previous, matched_path); + matched_path.into() + } else { + Arc::clone(matched_path) + }; + req.extensions_mut().insert(MatchedPath(matched_path)); } let params = match_ @@ -754,11 +769,6 @@ where .map(|(key, value)| (key.to_string(), value.to_string())) .collect::>(); - if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { - req.extensions_mut() - .insert(nested::NestMatchTail(tail.to_string())); - } - insert_url_params(&mut req, params); let route = self @@ -970,14 +980,22 @@ impl Fallback { } } -#[cfg(test)] -mod tests { - use super::*; +fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + use std::any::Any; - #[test] - fn traits() { - use crate::tests::*; - - assert_send::>(); + let k = Box::new(k) as Box; + match k.downcast() { + Ok(t) => Ok(*t), + Err(other) => Err(*other.downcast().unwrap()), } } + +#[test] +fn traits() { + use crate::tests::*; + assert_send::>(); +} diff --git a/src/routing/nested.rs b/src/routing/nested.rs deleted file mode 100644 index 9ef7a1e9..00000000 --- a/src/routing/nested.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::body::BoxBody; -use http::{Request, Response, Uri}; -use std::{ - convert::Infallible, - task::{Context, Poll}, -}; -use tower::util::Oneshot; -use tower::ServiceExt; -use tower_service::Service; - -/// A [`Service`] that has been nested inside a router at some path. -/// -/// Created with [`Router::nest`]. -#[derive(Debug, Clone)] -pub(super) struct Nested { - pub(super) svc: S, -} - -impl Service> for Nested -where - S: Service, Response = Response, Error = Infallible> + Clone, - B: Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = Oneshot>; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - // strip the prefix from the URI just before calling the inner service - // such that any surrounding middleware still see the full path - if let Some(tail) = req.extensions_mut().remove::() { - UriStack::push(&mut req); - let new_uri = super::with_path(req.uri(), &tail.0); - *req.uri_mut() = new_uri; - } - - self.svc.clone().oneshot(req) - } -} - -pub(crate) struct UriStack(Vec); - -impl UriStack { - fn push(req: &mut Request) { - let uri = req.uri().clone(); - - if let Some(stack) = req.extensions_mut().get_mut::() { - stack.0.push(uri); - } else { - req.extensions_mut().insert(Self(vec![uri])); - } - } -} - -#[derive(Clone)] -pub(super) struct NestMatchTail(pub(super) String); - -#[test] -fn traits() { - use crate::tests::*; - - assert_send::>(); - assert_sync::>(); -} diff --git a/src/routing/strip_prefix.rs b/src/routing/strip_prefix.rs new file mode 100644 index 00000000..c7950538 --- /dev/null +++ b/src/routing/strip_prefix.rs @@ -0,0 +1,77 @@ +use http::{Request, Uri}; +use std::{ + borrow::Cow, + sync::Arc, + task::{Context, Poll}, +}; +use tower_service::Service; + +#[derive(Clone)] +pub(super) struct StripPrefix { + inner: S, + prefix: Arc, +} + +impl StripPrefix { + pub(super) fn new(inner: S, prefix: &str) -> Self { + Self { + inner, + prefix: prefix.into(), + } + } +} + +impl Service> for StripPrefix +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let new_uri = strip_prefix(req.uri(), &self.prefix); + *req.uri_mut() = new_uri; + self.inner.call(req) + } +} + +fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { + let path_and_query = if let Some(path_and_query) = uri.path_and_query() { + let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { + path + } else { + path_and_query.path() + }; + + let new_path = if new_path.starts_with('/') { + Cow::Borrowed(new_path) + } else { + Cow::Owned(format!("/{}", new_path)) + }; + + if let Some(query) = path_and_query.query() { + Some( + format!("{}?{}", new_path, query) + .parse::() + .unwrap(), + ) + } else { + Some(new_path.parse().unwrap()) + } + } else { + None + }; + + let mut parts = http::uri::Parts::default(); + parts.scheme = uri.scheme().cloned(); + parts.authority = uri.authority().cloned(); + parts.path_and_query = path_and_query; + + Uri::from_parts(parts).unwrap() +} diff --git a/src/tests/fallback.rs b/src/tests/fallback.rs index c5708929..8107c13f 100644 --- a/src/tests/fallback.rs +++ b/src/tests/fallback.rs @@ -31,29 +31,6 @@ async fn nest() { assert_eq!(res.text().await, "fallback"); } -#[tokio::test] -async fn nesting_with_fallback() { - let app = Router::new().nest( - "/foo", - Router::new() - .route("/bar", get(|| async {})) - .fallback((|| async { "fallback" }).into_service()), - ); - - let client = TestClient::new(app); - - assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); - - // this shouldn't exist because the fallback is inside the nested router - let res = client.get("/does-not-exist").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - // this should work since we get into the nested router - let res = client.get("/foo/does-not-exist").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "fallback"); -} - #[tokio::test] async fn or() { let one = Router::new().route("/one", get(|| async {})); diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 0ce3595e..fc6c314b 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,7 +1,7 @@ #![allow(clippy::blacklisted_name)] use crate::error_handling::HandleErrorLayer; -use crate::extract::MatchedPath; +use crate::extract::{Extension, MatchedPath}; use crate::BoxError; use crate::{ extract::{self, Path}, @@ -29,7 +29,6 @@ use std::{ }; use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder}; use tower_http::auth::RequireAuthorizationLayer; -use tower_http::trace::TraceLayer; use tower_service::Service; pub(crate) use helpers::*; @@ -592,24 +591,92 @@ async fn with_and_without_trailing_slash() { assert_eq!(res.text().await, "without tsr"); } +#[derive(Clone)] +struct SetMatchedPathExtension(S); + +impl Service> for SetMatchedPathExtension +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let path = req + .extensions() + .get::() + .unwrap() + .as_str() + .to_string(); + req.extensions_mut().insert(MatchedPathFromMiddleware(path)); + self.0.call(req) + } +} + +#[derive(Clone)] +struct MatchedPathFromMiddleware(String); + #[tokio::test] async fn access_matched_path() { + let api = Router::new().route( + "/users/:id", + get(|path: MatchedPath| async move { path.as_str().to_string() }), + ); + + async fn handler( + path: MatchedPath, + Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension< + MatchedPathFromMiddleware, + >, + ) -> String { + format!( + "extractor = {}, middleware = {}", + path.as_str(), + path_from_middleware + ) + } + let app = Router::new() .route( "/:key", get(|path: MatchedPath| async move { path.as_str().to_string() }), ) - .layer( - TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { - let path = req.extensions().get::().unwrap().as_str(); - tracing::info_span!("http-request", %path) - }), - ); + .nest("/api", api) + .nest( + "/public", + Router::new().route("/assets/*path", get(handler)), + ) + .nest("/foo", handler.into_service()) + .layer(tower::layer::layer_fn(SetMatchedPathExtension)); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.text().await, "/:key"); + + let res = client.get("/api/users/123").send().await; + assert_eq!(res.text().await, "/api/users/:id"); + + let res = client.get("/public/assets/css/style.css").send().await; + assert_eq!( + res.text().await, + "extractor = /public/assets/*path, middleware = /public/assets/*path" + ); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!( + res.text().await, + format!( + "extractor = /foo/*{}, middleware = /foo/*{}", + crate::routing::NEST_TAIL_PARAM, + crate::routing::NEST_TAIL_PARAM, + ), + ); } #[tokio::test]