mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 23:38:20 +01:00
Make sure nested services still see full URI (#166)
They'd previously see the nested URI as we mutated the request. Now we always route based on the nested URI (if present) without mutating the request. Also meant we could get rid of `OriginalUri` which is nice.
This commit is contained in:
parent
0674c9136a
commit
bc27b09f5c
3 changed files with 59 additions and 28 deletions
|
@ -244,7 +244,7 @@
|
|||
//!
|
||||
//! [`body::Body`]: crate::body::Body
|
||||
|
||||
use crate::{response::IntoResponse, routing::OriginalUri};
|
||||
use crate::response::IntoResponse;
|
||||
use async_trait::async_trait;
|
||||
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
||||
use rejection::*;
|
||||
|
@ -378,7 +378,7 @@ impl<B> RequestParts<B> {
|
|||
let (
|
||||
http::request::Parts {
|
||||
method,
|
||||
mut uri,
|
||||
uri,
|
||||
version,
|
||||
headers,
|
||||
extensions,
|
||||
|
@ -387,10 +387,6 @@ impl<B> RequestParts<B> {
|
|||
body,
|
||||
) = req.into_parts();
|
||||
|
||||
if let Some(original_uri) = extensions.get::<OriginalUri>() {
|
||||
uri = original_uri.0.clone();
|
||||
};
|
||||
|
||||
RequestParts {
|
||||
method,
|
||||
uri,
|
||||
|
|
|
@ -459,7 +459,7 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
if let Some(captures) = self.pattern.full_match(req.uri().path()) {
|
||||
if let Some(captures) = self.pattern.full_match(&req) {
|
||||
insert_url_params(&mut req, captures);
|
||||
let fut = self.svc.clone().oneshot(req);
|
||||
RouteFuture::a(fut)
|
||||
|
@ -606,8 +606,8 @@ impl PathPattern {
|
|||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
|
||||
self.do_match(path).and_then(|match_| {
|
||||
pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
|
||||
self.do_match(req).and_then(|match_| {
|
||||
if match_.full_match {
|
||||
Some(match_.captures)
|
||||
} else {
|
||||
|
@ -616,12 +616,18 @@ impl PathPattern {
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
|
||||
self.do_match(path)
|
||||
pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
|
||||
self.do_match(req)
|
||||
.map(|match_| (match_.matched, match_.captures))
|
||||
}
|
||||
|
||||
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
|
||||
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
|
||||
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||
nested_uri.0.path()
|
||||
} else {
|
||||
req.uri().path()
|
||||
};
|
||||
|
||||
self.0.full_path_regex.captures(path).map(|captures| {
|
||||
let matched = captures.get(0).unwrap();
|
||||
let full_match = matched.as_str() == path;
|
||||
|
@ -864,16 +870,15 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
if req.extensions().get::<OriginalUri>().is_none() {
|
||||
let original_uri = OriginalUri(req.uri().clone());
|
||||
req.extensions_mut().insert(original_uri);
|
||||
}
|
||||
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
|
||||
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||
&nested_uri.0
|
||||
} else {
|
||||
req.uri()
|
||||
};
|
||||
|
||||
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
|
||||
let without_prefix = strip_prefix(req.uri(), prefix);
|
||||
req.extensions_mut()
|
||||
.insert(NestedUri(without_prefix.clone()));
|
||||
*req.uri_mut() = without_prefix;
|
||||
let without_prefix = strip_prefix(uri, prefix);
|
||||
req.extensions_mut().insert(NestedUri(without_prefix));
|
||||
|
||||
insert_url_params(&mut req, captures);
|
||||
let fut = self.svc.clone().oneshot(req);
|
||||
|
@ -887,11 +892,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// `Nested` changes the incoming requests URI. This will be saved as an
|
||||
/// extension so extractors can still access the original URI.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct OriginalUri(pub(crate) Uri);
|
||||
|
||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||
|
@ -953,8 +953,9 @@ mod tests {
|
|||
|
||||
fn assert_match(route_spec: &'static str, path: &'static str) {
|
||||
let route = PathPattern::new(route_spec);
|
||||
let req = Request::builder().uri(path).body(()).unwrap();
|
||||
assert!(
|
||||
route.full_match(path).is_some(),
|
||||
route.full_match(&req).is_some(),
|
||||
"`{}` doesn't match `{}`",
|
||||
path,
|
||||
route_spec
|
||||
|
@ -963,8 +964,9 @@ mod tests {
|
|||
|
||||
fn refute_match(route_spec: &'static str, path: &'static str) {
|
||||
let route = PathPattern::new(route_spec);
|
||||
let req = Request::builder().uri(path).body(()).unwrap();
|
||||
assert!(
|
||||
route.full_match(path).is_none(),
|
||||
route.full_match(&req).is_none(),
|
||||
"`{}` did match `{}` (but shouldn't)",
|
||||
path,
|
||||
route_spec
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::*;
|
||||
use crate::body::box_body;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -149,6 +150,7 @@ async fn nested_url_extractor() {
|
|||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||
|
||||
let res = client
|
||||
|
@ -156,6 +158,7 @@ async fn nested_url_extractor() {
|
|||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
|
||||
}
|
||||
|
||||
|
@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() {
|
|||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/baz");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nested_service_sees_original_uri() {
|
||||
let app = nest(
|
||||
"/foo",
|
||||
nest(
|
||||
"/bar",
|
||||
route(
|
||||
"/baz",
|
||||
service_fn(|req: Request<Body>| async move {
|
||||
let body = box_body(Body::from(req.uri().to_string()));
|
||||
Ok::<_, Infallible>(Response::new(body))
|
||||
}),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar/baz", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue