From 92f96a201c08b37567bb572d533afa080725d730 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Oct 2021 22:31:22 +0200 Subject: [PATCH] Implement nesting as regular wildcard routes (#426) When fixing bugs with `MatchedPath` (introduced to fix https://github.com/tokio-rs/axum/issues/386) I realized nesting could basically be implemented using regular routes, if we can detect that the service passed to `nest` is in fact a `Router`. Then we can transfer over all its routes and add the prefix. This makes nesting much simpler in general and should also be slightly faster since we're no longer nesting routers. --- src/routing/mod.rs | 124 +++++++++++++++++++++--------------- src/routing/nested.rs | 69 -------------------- src/routing/strip_prefix.rs | 77 ++++++++++++++++++++++ src/tests/fallback.rs | 23 ------- src/tests/mod.rs | 83 +++++++++++++++++++++--- 5 files changed, 223 insertions(+), 153 deletions(-) delete mode 100644 src/routing/nested.rs create mode 100644 src/routing/strip_prefix.rs diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 8bb09d6c..3167a3a9 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -6,7 +6,7 @@ use crate::{ body::{box_body, Body, BoxBody}, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, - OriginalUri, + MatchedPath, OriginalUri, }, util::{ByteStr, PercentDecodedByteStr}, BoxError, @@ -33,9 +33,9 @@ pub mod service_method_router; mod into_make_service; mod method_filter; mod method_not_allowed; -mod nested; mod not_found; mod route; +mod strip_prefix; pub(crate) use self::method_not_allowed::MethodNotAllowed; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; @@ -92,7 +92,8 @@ impl fmt::Debug for Router { } } -const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture"; +pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest"; +const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest"; impl Router where @@ -342,20 +343,43 @@ where panic!("Invalid route: nested routes cannot contain wildcards (*)"); } - let id = RouteId::next(); + let prefix = path; - let path = if path == "/" { - format!("/*{}", NEST_TAIL_PARAM) - } else { - format!("{}/*{}", path, NEST_TAIL_PARAM) - }; + match try_downcast::, _>(svc) { + // if the user is nesting a `Router` we can implement nesting + // by simplying copying all the routes and adding the prefix in + // front + Ok(router) => { + let Router { + mut routes, + node, + fallback: _, + } = router; - if let Err(err) = self.node.insert(path, id) { - panic!("Invalid route: {}", err); + for (id, nested_path) in node.paths { + let route = routes.remove(&id).unwrap(); + let full_path = if &*nested_path == "/" { + path.to_string() + } else { + format!("{}{}", path, nested_path) + }; + self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix)); + } + + debug_assert!(routes.is_empty()); + } + // otherwise we add a wildcard route to the service + Err(svc) => { + let path = if path == "/" { + format!("/*{}", NEST_TAIL_PARAM) + } else { + format!("{}/*{}", path, NEST_TAIL_PARAM) + }; + + self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix)); + } } - self.routes.insert(id, Route::new(nested::Nested { svc })); - self } @@ -700,31 +724,8 @@ where /// ## When used with `Router::nest` /// /// If a router with a fallback is nested inside another router the fallback - /// will only apply to requests that matches the prefix: - /// - /// ```rust - /// use axum::{ - /// Router, - /// routing::get, - /// handler::Handler, - /// response::IntoResponse, - /// http::{StatusCode, Uri}, - /// }; - /// - /// let api = Router::new() - /// .route("/", get(|| async { /* ... */ })) - /// .fallback(api_fallback.into_service()); - /// - /// let app = Router::new().nest("/api", api); - /// - /// async fn api_fallback() -> impl IntoResponse { /* ... */ } - /// - /// // `api_fallback` will be called for `/api/some-unknown-path` but not for - /// // `/some-unknown-path` as the path doesn't start with `/api` - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` + /// of the nested router will be discarded and not used. This is such that + /// the outer router's fallback takes precedence. pub fn fallback(mut self, svc: T) -> Self where T: Service, Response = Response, Error = Infallible> @@ -743,8 +744,22 @@ where req.extensions_mut().insert(id); if let Some(matched_path) = self.node.paths.get(&id) { - req.extensions_mut() - .insert(crate::extract::MatchedPath(matched_path.clone())); + let matched_path = if let Some(previous) = req.extensions_mut().get::() { + // a previous `MatchedPath` might exist if we're inside a nested Router + let previous = if let Some(previous) = + previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) + { + previous + } else { + previous.as_str() + }; + + let matched_path = format!("{}{}", previous, matched_path); + matched_path.into() + } else { + Arc::clone(matched_path) + }; + req.extensions_mut().insert(MatchedPath(matched_path)); } let params = match_ @@ -754,11 +769,6 @@ where .map(|(key, value)| (key.to_string(), value.to_string())) .collect::>(); - if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { - req.extensions_mut() - .insert(nested::NestMatchTail(tail.to_string())); - } - insert_url_params(&mut req, params); let route = self @@ -970,14 +980,22 @@ impl Fallback { } } -#[cfg(test)] -mod tests { - use super::*; +fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + use std::any::Any; - #[test] - fn traits() { - use crate::tests::*; - - assert_send::>(); + let k = Box::new(k) as Box; + match k.downcast() { + Ok(t) => Ok(*t), + Err(other) => Err(*other.downcast().unwrap()), } } + +#[test] +fn traits() { + use crate::tests::*; + assert_send::>(); +} diff --git a/src/routing/nested.rs b/src/routing/nested.rs deleted file mode 100644 index 9ef7a1e9..00000000 --- a/src/routing/nested.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::body::BoxBody; -use http::{Request, Response, Uri}; -use std::{ - convert::Infallible, - task::{Context, Poll}, -}; -use tower::util::Oneshot; -use tower::ServiceExt; -use tower_service::Service; - -/// A [`Service`] that has been nested inside a router at some path. -/// -/// Created with [`Router::nest`]. -#[derive(Debug, Clone)] -pub(super) struct Nested { - pub(super) svc: S, -} - -impl Service> for Nested -where - S: Service, Response = Response, Error = Infallible> + Clone, - B: Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = Oneshot>; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - // strip the prefix from the URI just before calling the inner service - // such that any surrounding middleware still see the full path - if let Some(tail) = req.extensions_mut().remove::() { - UriStack::push(&mut req); - let new_uri = super::with_path(req.uri(), &tail.0); - *req.uri_mut() = new_uri; - } - - self.svc.clone().oneshot(req) - } -} - -pub(crate) struct UriStack(Vec); - -impl UriStack { - fn push(req: &mut Request) { - let uri = req.uri().clone(); - - if let Some(stack) = req.extensions_mut().get_mut::() { - stack.0.push(uri); - } else { - req.extensions_mut().insert(Self(vec![uri])); - } - } -} - -#[derive(Clone)] -pub(super) struct NestMatchTail(pub(super) String); - -#[test] -fn traits() { - use crate::tests::*; - - assert_send::>(); - assert_sync::>(); -} diff --git a/src/routing/strip_prefix.rs b/src/routing/strip_prefix.rs new file mode 100644 index 00000000..c7950538 --- /dev/null +++ b/src/routing/strip_prefix.rs @@ -0,0 +1,77 @@ +use http::{Request, Uri}; +use std::{ + borrow::Cow, + sync::Arc, + task::{Context, Poll}, +}; +use tower_service::Service; + +#[derive(Clone)] +pub(super) struct StripPrefix { + inner: S, + prefix: Arc, +} + +impl StripPrefix { + pub(super) fn new(inner: S, prefix: &str) -> Self { + Self { + inner, + prefix: prefix.into(), + } + } +} + +impl Service> for StripPrefix +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let new_uri = strip_prefix(req.uri(), &self.prefix); + *req.uri_mut() = new_uri; + self.inner.call(req) + } +} + +fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { + let path_and_query = if let Some(path_and_query) = uri.path_and_query() { + let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { + path + } else { + path_and_query.path() + }; + + let new_path = if new_path.starts_with('/') { + Cow::Borrowed(new_path) + } else { + Cow::Owned(format!("/{}", new_path)) + }; + + if let Some(query) = path_and_query.query() { + Some( + format!("{}?{}", new_path, query) + .parse::() + .unwrap(), + ) + } else { + Some(new_path.parse().unwrap()) + } + } else { + None + }; + + let mut parts = http::uri::Parts::default(); + parts.scheme = uri.scheme().cloned(); + parts.authority = uri.authority().cloned(); + parts.path_and_query = path_and_query; + + Uri::from_parts(parts).unwrap() +} diff --git a/src/tests/fallback.rs b/src/tests/fallback.rs index c5708929..8107c13f 100644 --- a/src/tests/fallback.rs +++ b/src/tests/fallback.rs @@ -31,29 +31,6 @@ async fn nest() { assert_eq!(res.text().await, "fallback"); } -#[tokio::test] -async fn nesting_with_fallback() { - let app = Router::new().nest( - "/foo", - Router::new() - .route("/bar", get(|| async {})) - .fallback((|| async { "fallback" }).into_service()), - ); - - let client = TestClient::new(app); - - assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); - - // this shouldn't exist because the fallback is inside the nested router - let res = client.get("/does-not-exist").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - // this should work since we get into the nested router - let res = client.get("/foo/does-not-exist").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "fallback"); -} - #[tokio::test] async fn or() { let one = Router::new().route("/one", get(|| async {})); diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 0ce3595e..fc6c314b 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,7 +1,7 @@ #![allow(clippy::blacklisted_name)] use crate::error_handling::HandleErrorLayer; -use crate::extract::MatchedPath; +use crate::extract::{Extension, MatchedPath}; use crate::BoxError; use crate::{ extract::{self, Path}, @@ -29,7 +29,6 @@ use std::{ }; use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder}; use tower_http::auth::RequireAuthorizationLayer; -use tower_http::trace::TraceLayer; use tower_service::Service; pub(crate) use helpers::*; @@ -592,24 +591,92 @@ async fn with_and_without_trailing_slash() { assert_eq!(res.text().await, "without tsr"); } +#[derive(Clone)] +struct SetMatchedPathExtension(S); + +impl Service> for SetMatchedPathExtension +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let path = req + .extensions() + .get::() + .unwrap() + .as_str() + .to_string(); + req.extensions_mut().insert(MatchedPathFromMiddleware(path)); + self.0.call(req) + } +} + +#[derive(Clone)] +struct MatchedPathFromMiddleware(String); + #[tokio::test] async fn access_matched_path() { + let api = Router::new().route( + "/users/:id", + get(|path: MatchedPath| async move { path.as_str().to_string() }), + ); + + async fn handler( + path: MatchedPath, + Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension< + MatchedPathFromMiddleware, + >, + ) -> String { + format!( + "extractor = {}, middleware = {}", + path.as_str(), + path_from_middleware + ) + } + let app = Router::new() .route( "/:key", get(|path: MatchedPath| async move { path.as_str().to_string() }), ) - .layer( - TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { - let path = req.extensions().get::().unwrap().as_str(); - tracing::info_span!("http-request", %path) - }), - ); + .nest("/api", api) + .nest( + "/public", + Router::new().route("/assets/*path", get(handler)), + ) + .nest("/foo", handler.into_service()) + .layer(tower::layer::layer_fn(SetMatchedPathExtension)); let client = TestClient::new(app); let res = client.get("/foo").send().await; assert_eq!(res.text().await, "/:key"); + + let res = client.get("/api/users/123").send().await; + assert_eq!(res.text().await, "/api/users/:id"); + + let res = client.get("/public/assets/css/style.css").send().await; + assert_eq!( + res.text().await, + "extractor = /public/assets/*path, middleware = /public/assets/*path" + ); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!( + res.text().await, + format!( + "extractor = /foo/*{}, middleware = /foo/*{}", + crate::routing::NEST_TAIL_PARAM, + crate::routing::NEST_TAIL_PARAM, + ), + ); } #[tokio::test]