From 91981db8c7c42478c0efb37036b029c1ff3c9e24 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Oct 2021 19:29:59 +0200 Subject: [PATCH] Fix middleware not seeing the full URI for nested routes (#423) We stripped the prefix before calling any middleware. Instead the prefix should be stripped just before calling the actual route. This way middleware still see the full URI, as it used to work in 0.2. Fixes #419 --- src/routing/mod.rs | 20 ++++++++++++------ src/tests/mod.rs | 2 -- src/tests/nest.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 8a3f92e3..41de488b 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -440,8 +440,7 @@ where LayeredResBody::Error: Into, { let layer = ServiceBuilder::new() - .layer_fn(Route) - .layer_fn(CloneBoxService::new) + .layer_fn(Route::new) .layer(MapResponseBodyLayer::new(box_body)) .layer(layer); @@ -755,9 +754,7 @@ where .collect::>(); if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { - UriStack::push(&mut req); - let new_uri = with_path(req.uri(), tail); - *req.uri_mut() = new_uri; + req.extensions_mut().insert(NestMatchTail(tail.to_string())); } insert_url_params(&mut req, params); @@ -772,6 +769,9 @@ where } } +#[derive(Clone)] +struct NestMatchTail(String); + impl Service> for Router where B: Send + Sync + 'static, @@ -903,7 +903,15 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { + // strip the prefix from the URI just before calling the inner service + // such that any surrounding middleware still see the full path + if let Some(tail) = req.extensions_mut().remove::() { + UriStack::push(&mut req); + let new_uri = with_path(req.uri(), &tail.0); + *req.uri_mut() = new_uri; + } + NestedFuture { inner: self.svc.clone().oneshot(req), } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index f1c1e98c..0ce3595e 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -684,8 +684,6 @@ async fn middleware_still_run_for_unmatched_requests() { assert_eq!(COUNT.load(Ordering::SeqCst), 2); } -// TODO(david): middleware still run for empty routers - pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {} diff --git a/src/tests/nest.rs b/src/tests/nest.rs index db3f3585..96758444 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,6 +1,7 @@ use super::*; use crate::body::box_body; use crate::error_handling::HandleErrorExt; +use crate::extract::Extension; use std::collections::HashMap; #[tokio::test] @@ -257,3 +258,54 @@ async fn multiple_top_level_nests() { async fn nest_cannot_contain_wildcards() { Router::::new().nest("/one/*rest", Router::new()); } + +#[tokio::test] +async fn outer_middleware_still_see_whole_url() { + #[derive(Clone)] + struct SetUriExtension(S); + + #[derive(Clone)] + struct Uri(http::Uri); + + impl Service> for SetUriExtension + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let uri = Uri(req.uri().clone()); + req.extensions_mut().insert(uri); + self.0.call(req) + } + } + + async fn handler(Extension(Uri(middleware_uri)): Extension) -> impl IntoResponse { + middleware_uri.to_string() + } + + let app = Router::new() + .route("/", get(handler)) + .route("/foo", get(handler)) + .route("/foo/bar", get(handler)) + .nest("/one", Router::new().route("/two", get(handler))) + .fallback(handler.into_service()) + .layer(tower::layer::layer_fn(SetUriExtension)); + + let client = TestClient::new(app); + + assert_eq!(client.get("/").send().await.text().await, "/"); + assert_eq!(client.get("/foo").send().await.text().await, "/foo"); + assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar"); + assert_eq!( + client.get("/not-found").send().await.text().await, + "/not-found" + ); + assert_eq!(client.get("/one/two").send().await.text().await, "/one/two"); +}