Fix middleware not seeing the full URI for nested routes (#423)

We stripped the prefix before calling any middleware. Instead the prefix
should be stripped just before calling the actual route. This way
middleware still see the full URI, as it used to work in 0.2.

Fixes #419
This commit is contained in:
David Pedersen 2021-10-26 19:29:59 +02:00 committed by GitHub
parent 8fe4eaf1d5
commit 91981db8c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 8 deletions

View file

@ -440,8 +440,7 @@ where
LayeredResBody::Error: Into<BoxError>, LayeredResBody::Error: Into<BoxError>,
{ {
let layer = ServiceBuilder::new() let layer = ServiceBuilder::new()
.layer_fn(Route) .layer_fn(Route::new)
.layer_fn(CloneBoxService::new)
.layer(MapResponseBodyLayer::new(box_body)) .layer(MapResponseBodyLayer::new(box_body))
.layer(layer); .layer(layer);
@ -755,9 +754,7 @@ where
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) { if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
UriStack::push(&mut req); req.extensions_mut().insert(NestMatchTail(tail.to_string()));
let new_uri = with_path(req.uri(), tail);
*req.uri_mut() = new_uri;
} }
insert_url_params(&mut req, params); insert_url_params(&mut req, params);
@ -772,6 +769,9 @@ where
} }
} }
#[derive(Clone)]
struct NestMatchTail(String);
impl<B> Service<Request<B>> for Router<B> impl<B> Service<Request<B>> for Router<B>
where where
B: Send + Sync + 'static, B: Send + Sync + 'static,
@ -903,7 +903,15 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, mut req: Request<B>) -> Self::Future {
// strip the prefix from the URI just before calling the inner service
// such that any surrounding middleware still see the full path
if let Some(tail) = req.extensions_mut().remove::<NestMatchTail>() {
UriStack::push(&mut req);
let new_uri = with_path(req.uri(), &tail.0);
*req.uri_mut() = new_uri;
}
NestedFuture { NestedFuture {
inner: self.svc.clone().oneshot(req), inner: self.svc.clone().oneshot(req),
} }

View file

@ -684,8 +684,6 @@ async fn middleware_still_run_for_unmatched_requests() {
assert_eq!(COUNT.load(Ordering::SeqCst), 2); assert_eq!(COUNT.load(Ordering::SeqCst), 2);
} }
// TODO(david): middleware still run for empty routers
pub(crate) fn assert_send<T: Send>() {} pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {} pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {} pub(crate) fn assert_unpin<T: Unpin>() {}

View file

@ -1,6 +1,7 @@
use super::*; use super::*;
use crate::body::box_body; use crate::body::box_body;
use crate::error_handling::HandleErrorExt; use crate::error_handling::HandleErrorExt;
use crate::extract::Extension;
use std::collections::HashMap; use std::collections::HashMap;
#[tokio::test] #[tokio::test]
@ -257,3 +258,54 @@ async fn multiple_top_level_nests() {
async fn nest_cannot_contain_wildcards() { async fn nest_cannot_contain_wildcards() {
Router::<Body>::new().nest("/one/*rest", Router::new()); Router::<Body>::new().nest("/one/*rest", Router::new());
} }
#[tokio::test]
async fn outer_middleware_still_see_whole_url() {
#[derive(Clone)]
struct SetUriExtension<S>(S);
#[derive(Clone)]
struct Uri(http::Uri);
impl<B, S> Service<Request<B>> for SetUriExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let uri = Uri(req.uri().clone());
req.extensions_mut().insert(uri);
self.0.call(req)
}
}
async fn handler(Extension(Uri(middleware_uri)): Extension<Uri>) -> impl IntoResponse {
middleware_uri.to_string()
}
let app = Router::new()
.route("/", get(handler))
.route("/foo", get(handler))
.route("/foo/bar", get(handler))
.nest("/one", Router::new().route("/two", get(handler)))
.fallback(handler.into_service())
.layer(tower::layer::layer_fn(SetUriExtension));
let client = TestClient::new(app);
assert_eq!(client.get("/").send().await.text().await, "/");
assert_eq!(client.get("/foo").send().await.text().await, "/foo");
assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar");
assert_eq!(
client.get("/not-found").send().await.text().await,
"/not-found"
);
assert_eq!(client.get("/one/two").send().await.text().await, "/one/two");
}