mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Fix middleware not seeing the full URI for nested routes (#423)
We stripped the prefix before calling any middleware. Instead the prefix should be stripped just before calling the actual route. This way middleware still see the full URI, as it used to work in 0.2. Fixes #419
This commit is contained in:
parent
8fe4eaf1d5
commit
91981db8c7
3 changed files with 66 additions and 8 deletions
|
@ -440,8 +440,7 @@ where
|
|||
LayeredResBody::Error: Into<BoxError>,
|
||||
{
|
||||
let layer = ServiceBuilder::new()
|
||||
.layer_fn(Route)
|
||||
.layer_fn(CloneBoxService::new)
|
||||
.layer_fn(Route::new)
|
||||
.layer(MapResponseBodyLayer::new(box_body))
|
||||
.layer(layer);
|
||||
|
||||
|
@ -755,9 +754,7 @@ where
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
|
||||
UriStack::push(&mut req);
|
||||
let new_uri = with_path(req.uri(), tail);
|
||||
*req.uri_mut() = new_uri;
|
||||
req.extensions_mut().insert(NestMatchTail(tail.to_string()));
|
||||
}
|
||||
|
||||
insert_url_params(&mut req, params);
|
||||
|
@ -772,6 +769,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NestMatchTail(String);
|
||||
|
||||
impl<B> Service<Request<B>> for Router<B>
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
|
@ -903,7 +903,15 @@ where
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
// strip the prefix from the URI just before calling the inner service
|
||||
// such that any surrounding middleware still see the full path
|
||||
if let Some(tail) = req.extensions_mut().remove::<NestMatchTail>() {
|
||||
UriStack::push(&mut req);
|
||||
let new_uri = with_path(req.uri(), &tail.0);
|
||||
*req.uri_mut() = new_uri;
|
||||
}
|
||||
|
||||
NestedFuture {
|
||||
inner: self.svc.clone().oneshot(req),
|
||||
}
|
||||
|
|
|
@ -684,8 +684,6 @@ async fn middleware_still_run_for_unmatched_requests() {
|
|||
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
// TODO(david): middleware still run for empty routers
|
||||
|
||||
pub(crate) fn assert_send<T: Send>() {}
|
||||
pub(crate) fn assert_sync<T: Sync>() {}
|
||||
pub(crate) fn assert_unpin<T: Unpin>() {}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::*;
|
||||
use crate::body::box_body;
|
||||
use crate::error_handling::HandleErrorExt;
|
||||
use crate::extract::Extension;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -257,3 +258,54 @@ async fn multiple_top_level_nests() {
|
|||
async fn nest_cannot_contain_wildcards() {
|
||||
Router::<Body>::new().nest("/one/*rest", Router::new());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn outer_middleware_still_see_whole_url() {
|
||||
#[derive(Clone)]
|
||||
struct SetUriExtension<S>(S);
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Uri(http::Uri);
|
||||
|
||||
impl<B, S> Service<Request<B>> for SetUriExtension<S>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.0.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
let uri = Uri(req.uri().clone());
|
||||
req.extensions_mut().insert(uri);
|
||||
self.0.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
async fn handler(Extension(Uri(middleware_uri)): Extension<Uri>) -> impl IntoResponse {
|
||||
middleware_uri.to_string()
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(handler))
|
||||
.route("/foo", get(handler))
|
||||
.route("/foo/bar", get(handler))
|
||||
.nest("/one", Router::new().route("/two", get(handler)))
|
||||
.fallback(handler.into_service())
|
||||
.layer(tower::layer::layer_fn(SetUriExtension));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
assert_eq!(client.get("/").send().await.text().await, "/");
|
||||
assert_eq!(client.get("/foo").send().await.text().await, "/foo");
|
||||
assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar");
|
||||
assert_eq!(
|
||||
client.get("/not-found").send().await.text().await,
|
||||
"/not-found"
|
||||
);
|
||||
assert_eq!(client.get("/one/two").send().await.text().await, "/one/two");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue