mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-16 22:43:03 +01:00
Make sure nested services still see full URI (#166)
They'd previously see the nested URI as we mutated the request. Now we always route based on the nested URI (if present) without mutating the request. Also meant we could get rid of `OriginalUri` which is nice.
This commit is contained in:
parent
0674c9136a
commit
bc27b09f5c
3 changed files with 59 additions and 28 deletions
|
@ -244,7 +244,7 @@
|
||||||
//!
|
//!
|
||||||
//! [`body::Body`]: crate::body::Body
|
//! [`body::Body`]: crate::body::Body
|
||||||
|
|
||||||
use crate::{response::IntoResponse, routing::OriginalUri};
|
use crate::response::IntoResponse;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
||||||
use rejection::*;
|
use rejection::*;
|
||||||
|
@ -378,7 +378,7 @@ impl<B> RequestParts<B> {
|
||||||
let (
|
let (
|
||||||
http::request::Parts {
|
http::request::Parts {
|
||||||
method,
|
method,
|
||||||
mut uri,
|
uri,
|
||||||
version,
|
version,
|
||||||
headers,
|
headers,
|
||||||
extensions,
|
extensions,
|
||||||
|
@ -387,10 +387,6 @@ impl<B> RequestParts<B> {
|
||||||
body,
|
body,
|
||||||
) = req.into_parts();
|
) = req.into_parts();
|
||||||
|
|
||||||
if let Some(original_uri) = extensions.get::<OriginalUri>() {
|
|
||||||
uri = original_uri.0.clone();
|
|
||||||
};
|
|
||||||
|
|
||||||
RequestParts {
|
RequestParts {
|
||||||
method,
|
method,
|
||||||
uri,
|
uri,
|
||||||
|
|
|
@ -459,7 +459,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||||
if let Some(captures) = self.pattern.full_match(req.uri().path()) {
|
if let Some(captures) = self.pattern.full_match(&req) {
|
||||||
insert_url_params(&mut req, captures);
|
insert_url_params(&mut req, captures);
|
||||||
let fut = self.svc.clone().oneshot(req);
|
let fut = self.svc.clone().oneshot(req);
|
||||||
RouteFuture::a(fut)
|
RouteFuture::a(fut)
|
||||||
|
@ -606,8 +606,8 @@ impl PathPattern {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
|
pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
|
||||||
self.do_match(path).and_then(|match_| {
|
self.do_match(req).and_then(|match_| {
|
||||||
if match_.full_match {
|
if match_.full_match {
|
||||||
Some(match_.captures)
|
Some(match_.captures)
|
||||||
} else {
|
} else {
|
||||||
|
@ -616,12 +616,18 @@ impl PathPattern {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
|
pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
|
||||||
self.do_match(path)
|
self.do_match(req)
|
||||||
.map(|match_| (match_.matched, match_.captures))
|
.map(|match_| (match_.matched, match_.captures))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
|
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
|
||||||
|
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||||
|
nested_uri.0.path()
|
||||||
|
} else {
|
||||||
|
req.uri().path()
|
||||||
|
};
|
||||||
|
|
||||||
self.0.full_path_regex.captures(path).map(|captures| {
|
self.0.full_path_regex.captures(path).map(|captures| {
|
||||||
let matched = captures.get(0).unwrap();
|
let matched = captures.get(0).unwrap();
|
||||||
let full_match = matched.as_str() == path;
|
let full_match = matched.as_str() == path;
|
||||||
|
@ -864,16 +870,15 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||||
if req.extensions().get::<OriginalUri>().is_none() {
|
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
|
||||||
let original_uri = OriginalUri(req.uri().clone());
|
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||||
req.extensions_mut().insert(original_uri);
|
&nested_uri.0
|
||||||
}
|
} else {
|
||||||
|
req.uri()
|
||||||
|
};
|
||||||
|
|
||||||
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
|
let without_prefix = strip_prefix(uri, prefix);
|
||||||
let without_prefix = strip_prefix(req.uri(), prefix);
|
req.extensions_mut().insert(NestedUri(without_prefix));
|
||||||
req.extensions_mut()
|
|
||||||
.insert(NestedUri(without_prefix.clone()));
|
|
||||||
*req.uri_mut() = without_prefix;
|
|
||||||
|
|
||||||
insert_url_params(&mut req, captures);
|
insert_url_params(&mut req, captures);
|
||||||
let fut = self.svc.clone().oneshot(req);
|
let fut = self.svc.clone().oneshot(req);
|
||||||
|
@ -887,11 +892,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `Nested` changes the incoming requests URI. This will be saved as an
|
|
||||||
/// extension so extractors can still access the original URI.
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub(crate) struct OriginalUri(pub(crate) Uri);
|
|
||||||
|
|
||||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||||
|
@ -953,8 +953,9 @@ mod tests {
|
||||||
|
|
||||||
fn assert_match(route_spec: &'static str, path: &'static str) {
|
fn assert_match(route_spec: &'static str, path: &'static str) {
|
||||||
let route = PathPattern::new(route_spec);
|
let route = PathPattern::new(route_spec);
|
||||||
|
let req = Request::builder().uri(path).body(()).unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
route.full_match(path).is_some(),
|
route.full_match(&req).is_some(),
|
||||||
"`{}` doesn't match `{}`",
|
"`{}` doesn't match `{}`",
|
||||||
path,
|
path,
|
||||||
route_spec
|
route_spec
|
||||||
|
@ -963,8 +964,9 @@ mod tests {
|
||||||
|
|
||||||
fn refute_match(route_spec: &'static str, path: &'static str) {
|
fn refute_match(route_spec: &'static str, path: &'static str) {
|
||||||
let route = PathPattern::new(route_spec);
|
let route = PathPattern::new(route_spec);
|
||||||
|
let req = Request::builder().uri(path).body(()).unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
route.full_match(path).is_none(),
|
route.full_match(&req).is_none(),
|
||||||
"`{}` did match `{}` (but shouldn't)",
|
"`{}` did match `{}` (but shouldn't)",
|
||||||
path,
|
path,
|
||||||
route_spec
|
route_spec
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::body::box_body;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -149,6 +150,7 @@ async fn nested_url_extractor() {
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||||
|
|
||||||
let res = client
|
let res = client
|
||||||
|
@ -156,6 +158,7 @@ async fn nested_url_extractor() {
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
|
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() {
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
assert_eq!(res.text().await.unwrap(), "/baz");
|
assert_eq!(res.text().await.unwrap(), "/baz");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn nested_service_sees_original_uri() {
|
||||||
|
let app = nest(
|
||||||
|
"/foo",
|
||||||
|
nest(
|
||||||
|
"/bar",
|
||||||
|
route(
|
||||||
|
"/baz",
|
||||||
|
service_fn(|req: Request<Body>| async move {
|
||||||
|
let body = box_body(Body::from(req.uri().to_string()));
|
||||||
|
Ok::<_, Infallible>(Response::new(body))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/foo/bar/baz", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue