diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 8a3f92e3..41de488b 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -440,8 +440,7 @@ where LayeredResBody::Error: Into, { let layer = ServiceBuilder::new() - .layer_fn(Route) - .layer_fn(CloneBoxService::new) + .layer_fn(Route::new) .layer(MapResponseBodyLayer::new(box_body)) .layer(layer); @@ -755,9 +754,7 @@ where .collect::>(); if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { - UriStack::push(&mut req); - let new_uri = with_path(req.uri(), tail); - *req.uri_mut() = new_uri; + req.extensions_mut().insert(NestMatchTail(tail.to_string())); } insert_url_params(&mut req, params); @@ -772,6 +769,9 @@ where } } +#[derive(Clone)] +struct NestMatchTail(String); + impl Service> for Router where B: Send + Sync + 'static, @@ -903,7 +903,15 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { + // strip the prefix from the URI just before calling the inner service + // such that any surrounding middleware still see the full path + if let Some(tail) = req.extensions_mut().remove::() { + UriStack::push(&mut req); + let new_uri = with_path(req.uri(), &tail.0); + *req.uri_mut() = new_uri; + } + NestedFuture { inner: self.svc.clone().oneshot(req), } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index f1c1e98c..0ce3595e 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -684,8 +684,6 @@ async fn middleware_still_run_for_unmatched_requests() { assert_eq!(COUNT.load(Ordering::SeqCst), 2); } -// TODO(david): middleware still run for empty routers - pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {} diff --git a/src/tests/nest.rs b/src/tests/nest.rs index db3f3585..96758444 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,6 +1,7 @@ use super::*; use crate::body::box_body; use crate::error_handling::HandleErrorExt; +use crate::extract::Extension; use std::collections::HashMap; #[tokio::test] @@ -257,3 +258,54 @@ async fn multiple_top_level_nests() { async fn nest_cannot_contain_wildcards() { Router::::new().nest("/one/*rest", Router::new()); } + +#[tokio::test] +async fn outer_middleware_still_see_whole_url() { + #[derive(Clone)] + struct SetUriExtension(S); + + #[derive(Clone)] + struct Uri(http::Uri); + + impl Service> for SetUriExtension + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let uri = Uri(req.uri().clone()); + req.extensions_mut().insert(uri); + self.0.call(req) + } + } + + async fn handler(Extension(Uri(middleware_uri)): Extension) -> impl IntoResponse { + middleware_uri.to_string() + } + + let app = Router::new() + .route("/", get(handler)) + .route("/foo", get(handler)) + .route("/foo/bar", get(handler)) + .nest("/one", Router::new().route("/two", get(handler))) + .fallback(handler.into_service()) + .layer(tower::layer::layer_fn(SetUriExtension)); + + let client = TestClient::new(app); + + assert_eq!(client.get("/").send().await.text().await, "/"); + assert_eq!(client.get("/foo").send().await.text().await, "/foo"); + assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar"); + assert_eq!( + client.get("/not-found").send().await.text().await, + "/not-found" + ); + assert_eq!(client.get("/one/two").send().await.text().await, "/one/two"); +}