Make sure nested services still see full URI (#166)

They'd previously see the nested URI as we mutated the request. Now we
always route based on the nested URI (if present) without mutating the
request. Also meant we could get rid of `OriginalUri` which is nice.
This commit is contained in:
David Pedersen 2021-08-08 17:27:23 +02:00 committed by GitHub
parent 0674c9136a
commit bc27b09f5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 28 deletions

View file

@ -244,7 +244,7 @@
//!
//! [`body::Body`]: crate::body::Body
use crate::{response::IntoResponse, routing::OriginalUri};
use crate::response::IntoResponse;
use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
@ -378,7 +378,7 @@ impl<B> RequestParts<B> {
let (
http::request::Parts {
method,
mut uri,
uri,
version,
headers,
extensions,
@ -387,10 +387,6 @@ impl<B> RequestParts<B> {
body,
) = req.into_parts();
if let Some(original_uri) = extensions.get::<OriginalUri>() {
uri = original_uri.0.clone();
};
RequestParts {
method,
uri,

View file

@ -459,7 +459,7 @@ where
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(captures) = self.pattern.full_match(req.uri().path()) {
if let Some(captures) = self.pattern.full_match(&req) {
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
@ -606,8 +606,8 @@ impl PathPattern {
}))
}
pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
self.do_match(path).and_then(|match_| {
pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
self.do_match(req).and_then(|match_| {
if match_.full_match {
Some(match_.captures)
} else {
@ -616,12 +616,18 @@ impl PathPattern {
})
}
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
self.do_match(path)
pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
self.do_match(req)
.map(|match_| (match_.matched, match_.captures))
}
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
nested_uri.0.path()
} else {
req.uri().path()
};
self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap();
let full_match = matched.as_str() == path;
@ -864,16 +870,15 @@ where
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
&nested_uri.0
} else {
req.uri()
};
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
let without_prefix = strip_prefix(req.uri(), prefix);
req.extensions_mut()
.insert(NestedUri(without_prefix.clone()));
*req.uri_mut() = without_prefix;
let without_prefix = strip_prefix(uri, prefix);
req.extensions_mut().insert(NestedUri(without_prefix));
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
@ -887,11 +892,6 @@ where
}
}
/// `Nested` changes the incoming requests URI. This will be saved as an
/// extension so extractors can still access the original URI.
#[derive(Clone)]
pub(crate) struct OriginalUri(pub(crate) Uri);
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
@ -953,8 +953,9 @@ mod tests {
fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(path).is_some(),
route.full_match(&req).is_some(),
"`{}` doesn't match `{}`",
path,
route_spec
@ -963,8 +964,9 @@ mod tests {
fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(path).is_none(),
route.full_match(&req).is_none(),
"`{}` did match `{}` (but shouldn't)",
path,
route_spec

View file

@ -1,4 +1,5 @@
use super::*;
use crate::body::box_body;
use std::collections::HashMap;
#[tokio::test]
@ -149,6 +150,7 @@ async fn nested_url_extractor() {
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
let res = client
@ -156,6 +158,7 @@ async fn nested_url_extractor() {
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
}
@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() {
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/baz");
}
#[tokio::test]
async fn nested_service_sees_original_uri() {
let app = nest(
"/foo",
nest(
"/bar",
route(
"/baz",
service_fn(|req: Request<Body>| async move {
let body = box_body(Body::from(req.uri().to_string()));
Ok::<_, Infallible>(Response::new(body))
}),
),
),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
}