Make sure nested services still see full URI (#166)

They'd previously see the nested URI as we mutated the request. Now we
always route based on the nested URI (if present) without mutating the
request. Also meant we could get rid of `OriginalUri` which is nice.
This commit is contained in:
David Pedersen 2021-08-08 17:27:23 +02:00 committed by GitHub
parent 0674c9136a
commit bc27b09f5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 28 deletions

View file

@ -244,7 +244,7 @@
//! //!
//! [`body::Body`]: crate::body::Body //! [`body::Body`]: crate::body::Body
use crate::{response::IntoResponse, routing::OriginalUri}; use crate::response::IntoResponse;
use async_trait::async_trait; use async_trait::async_trait;
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*; use rejection::*;
@ -378,7 +378,7 @@ impl<B> RequestParts<B> {
let ( let (
http::request::Parts { http::request::Parts {
method, method,
mut uri, uri,
version, version,
headers, headers,
extensions, extensions,
@ -387,10 +387,6 @@ impl<B> RequestParts<B> {
body, body,
) = req.into_parts(); ) = req.into_parts();
if let Some(original_uri) = extensions.get::<OriginalUri>() {
uri = original_uri.0.clone();
};
RequestParts { RequestParts {
method, method,
uri, uri,

View file

@ -459,7 +459,7 @@ where
} }
fn call(&mut self, mut req: Request<B>) -> Self::Future { fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(captures) = self.pattern.full_match(req.uri().path()) { if let Some(captures) = self.pattern.full_match(&req) {
insert_url_params(&mut req, captures); insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req); let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut) RouteFuture::a(fut)
@ -606,8 +606,8 @@ impl PathPattern {
})) }))
} }
pub(crate) fn full_match(&self, path: &str) -> Option<Captures> { pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
self.do_match(path).and_then(|match_| { self.do_match(req).and_then(|match_| {
if match_.full_match { if match_.full_match {
Some(match_.captures) Some(match_.captures)
} else { } else {
@ -616,12 +616,18 @@ impl PathPattern {
}) })
} }
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> { pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
self.do_match(path) self.do_match(req)
.map(|match_| (match_.matched, match_.captures)) .map(|match_| (match_.matched, match_.captures))
} }
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> { fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
nested_uri.0.path()
} else {
req.uri().path()
};
self.0.full_path_regex.captures(path).map(|captures| { self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap(); let matched = captures.get(0).unwrap();
let full_match = matched.as_str() == path; let full_match = matched.as_str() == path;
@ -864,16 +870,15 @@ where
} }
fn call(&mut self, mut req: Request<B>) -> Self::Future { fn call(&mut self, mut req: Request<B>) -> Self::Future {
if req.extensions().get::<OriginalUri>().is_none() { let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let original_uri = OriginalUri(req.uri().clone()); let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
req.extensions_mut().insert(original_uri); &nested_uri.0
} } else {
req.uri()
};
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(uri, prefix);
let without_prefix = strip_prefix(req.uri(), prefix); req.extensions_mut().insert(NestedUri(without_prefix));
req.extensions_mut()
.insert(NestedUri(without_prefix.clone()));
*req.uri_mut() = without_prefix;
insert_url_params(&mut req, captures); insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req); let fut = self.svc.clone().oneshot(req);
@ -887,11 +892,6 @@ where
} }
} }
/// `Nested` changes the incoming requests URI. This will be saved as an
/// extension so extractors can still access the original URI.
#[derive(Clone)]
pub(crate) struct OriginalUri(pub(crate) Uri);
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
@ -953,8 +953,9 @@ mod tests {
fn assert_match(route_spec: &'static str, path: &'static str) { fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!( assert!(
route.full_match(path).is_some(), route.full_match(&req).is_some(),
"`{}` doesn't match `{}`", "`{}` doesn't match `{}`",
path, path,
route_spec route_spec
@ -963,8 +964,9 @@ mod tests {
fn refute_match(route_spec: &'static str, path: &'static str) { fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!( assert!(
route.full_match(path).is_none(), route.full_match(&req).is_none(),
"`{}` did match `{}` (but shouldn't)", "`{}` did match `{}` (but shouldn't)",
path, path,
route_spec route_spec

View file

@ -1,4 +1,5 @@
use super::*; use super::*;
use crate::body::box_body;
use std::collections::HashMap; use std::collections::HashMap;
#[tokio::test] #[tokio::test]
@ -149,6 +150,7 @@ async fn nested_url_extractor() {
.send() .send()
.await .await
.unwrap(); .unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
let res = client let res = client
@ -156,6 +158,7 @@ async fn nested_url_extractor() {
.send() .send()
.await .await
.unwrap(); .unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux"); assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
} }
@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() {
.send() .send()
.await .await
.unwrap(); .unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/baz"); assert_eq!(res.text().await.unwrap(), "/baz");
} }
#[tokio::test]
async fn nested_service_sees_original_uri() {
let app = nest(
"/foo",
nest(
"/bar",
route(
"/baz",
service_fn(|req: Request<Body>| async move {
let body = box_body(Body::from(req.uri().to_string()));
Ok::<_, Infallible>(Response::new(body))
}),
),
),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
}