Change nested routes to see the URI with prefix stripped (#197)

This commit is contained in:
David Pedersen 2021-08-18 09:48:36 +02:00 committed by GitHub
parent cb637f1124
commit e22045d42f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 255 additions and 55 deletions

View file

@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
- Add `NestedUri` for extracting request URI in nested services ([#161](https://github.com/tokio-rs/axum/pull/161))
- Add `OriginalUri` for extracting original request URI in nested services ([#197](https://github.com/tokio-rs/axum/pull/197))
- Implement `FromRequest` for `http::Extensions`
- Implement SSE as an `IntoResponse` instead of a service ([#98](https://github.com/tokio-rs/axum/pull/98))
- Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193))
@ -35,7 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ensure a `HandleError` service created from `ServiceExt::handle_error`
_does not_ implement `RoutingDsl` as that could lead to confusing routing
behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))

View file

@ -51,6 +51,8 @@ futures = "0.3"
reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] }
uuid = { version = "0.8", features = ["serde", "v4"] }
tokio-stream = "0.1"
[dev-dependencies.tower]
version = "0.4"

View file

@ -316,7 +316,7 @@ pub use self::{
path::Path,
query::Query,
raw_query::RawQuery,
request_parts::NestedUri,
request_parts::OriginalUri,
request_parts::{Body, BodyStream},
};
#[doc(no_inline)]

View file

@ -102,16 +102,6 @@ define_rejection! {
pub struct InvalidFormContentType;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "`NestedUri` extractor used for route that isn't nested"]
/// Rejection type used if you try and extract [`NestedUri`] from a route that
/// isn't nested.
///
/// [`NestedUri`]: crate::extract::NestedUri
pub struct NotNested;
}
/// Rejection type for [`Path`](super::Path) if the capture route
/// param didn't have the expected type.
#[derive(Debug)]

View file

@ -82,10 +82,10 @@ where
}
}
/// Extractor that gets the request URI for a nested service.
/// Extractor that gets the original request URI regardless of nesting.
///
/// This is necessary since [`Uri`](http::Uri), when used as an extractor, will
/// always be the full URI.
/// have the prefix stripped if used in a nested service.
///
/// # Example
///
@ -94,15 +94,15 @@ where
/// handler::get,
/// route,
/// routing::{nest, RoutingDsl},
/// extract::NestedUri,
/// extract::OriginalUri,
/// http::Uri
/// };
///
/// let api_routes = route(
/// "/users",
/// get(|uri: Uri, NestedUri(nested_uri): NestedUri| async {
/// // `uri` is `/api/users`
/// // `nested_uri` is `/users`
/// get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
/// // `uri` is `/users`
/// // `original_uri` is `/api/users`
/// }),
/// );
///
@ -112,19 +112,19 @@ where
/// # };
/// ```
#[derive(Debug, Clone)]
pub struct NestedUri(pub Uri);
pub struct OriginalUri(pub Uri);
#[async_trait]
impl<B> FromRequest<B> for NestedUri
impl<B> FromRequest<B> for OriginalUri
where
B: Send,
{
type Rejection = NotNested;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request(req)
.await
.map_err(|_| NotNested)?
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
.0;
Ok(uri)
}

View file

@ -304,6 +304,11 @@
//! # };
//! ```
//!
//! Note that nested routes will not see the orignal request URI but instead
//! have the matched prefix stripped. This is necessary for services like static
//! file serving to work. Use [`OriginalUri`] if you need the original request
//! URI.
//!
//! # Extractors
//!
//! An extractor is a type that implements [`FromRequest`]. Extractors is how
@ -747,6 +752,7 @@
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
//! [`RoutingDsl::or`]: crate::routing::RoutingDsl::or
//! [`axum::Server`]: hyper::server::Server
//! [`OriginalUri`]: crate::extract::OriginalUri
#![warn(
clippy::all,

View file

@ -6,7 +6,7 @@ use crate::{
buffer::MpscBuffer,
extract::{
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
NestedUri,
OriginalUri,
},
service::{HandleError, HandleErrorFromRouter},
util::ByteStr,
@ -690,11 +690,7 @@ impl PathPattern {
}
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
nested_uri.0.path()
} else {
req.uri().path()
};
let path = req.uri().path();
self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap();
@ -948,15 +944,14 @@ where
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
&nested_uri.0
} else {
req.uri()
};
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
let without_prefix = strip_prefix(uri, prefix);
req.extensions_mut().insert(NestedUri(without_prefix));
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let without_prefix = strip_prefix(req.uri(), prefix);
*req.uri_mut() = without_prefix;
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);

View file

@ -47,6 +47,8 @@ where
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let original_uri = req.uri().clone();
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
count.increment();
} else {
@ -58,6 +60,7 @@ where
f: self.first.clone().oneshot(req),
},
second: Some(self.second.clone()),
original_uri: Some(original_uri),
}
}
}
@ -72,6 +75,9 @@ pin_project! {
#[pin]
state: State<A, B, ReqBody>,
second: Option<B>,
// Some services, namely `Nested`, mutates the request URI so we must
// restore it to its original state before calling `second`
original_uri: Option<http::Uri>,
}
}
@ -115,6 +121,8 @@ where
return Poll::Ready(Ok(response));
};
*req.uri_mut() = this.original_uri.take().unwrap();
let mut leaving_outermost_or = false;
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
if depth == 1 {

View file

@ -151,7 +151,7 @@ async fn nested_url_extractor() {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
assert_eq!(res.text().await.unwrap(), "/baz");
let res = client
.get(format!("http://{}/foo/bar/qux", addr))
@ -159,18 +159,18 @@ async fn nested_url_extractor() {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
assert_eq!(res.text().await.unwrap(), "/qux");
}
#[tokio::test]
async fn nested_url_nested_extractor() {
async fn nested_url_original_extractor() {
let app = nest(
"/foo",
nest(
"/bar",
route(
"/baz",
get(|uri: extract::NestedUri| async move { uri.0.to_string() }),
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
),
),
);
@ -185,11 +185,11 @@ async fn nested_url_nested_extractor() {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/baz");
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
}
#[tokio::test]
async fn nested_service_sees_original_uri() {
async fn nested_service_sees_stripped_uri() {
let app = nest(
"/foo",
nest(
@ -214,5 +214,29 @@ async fn nested_service_sees_original_uri() {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
assert_eq!(res.text().await.unwrap(), "/baz");
}
#[tokio::test]
async fn nest_static_file_server() {
let app = nest(
"/static",
service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| {
Ok::<_, Infallible>((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
}),
);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/static/README.md", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}

View file

@ -1,5 +1,8 @@
use serde_json::{json, Value};
use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
use crate::{extract::OriginalUri, response::IntoResponse, Json};
use super::*;
#[tokio::test]
@ -302,17 +305,190 @@ async fn services() {
assert_eq!(res.status(), StatusCode::OK);
}
// TODO(david): can we make this not compile?
// #[tokio::test]
// async fn foo() {
// let svc_one = service_fn(|_: Request<Body>| async {
// Ok::<_, hyper::Error>(Response::new(Body::empty()))
// })
// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR));
async fn all_the_uris(
uri: Uri,
OriginalUri(original_uri): OriginalUri,
req: Request<Body>,
) -> impl IntoResponse {
Json(json!({
"uri": uri.to_string(),
"request_uri": req.uri().to_string(),
"original_uri": original_uri.to_string(),
}))
}
// let svc_two = svc_one.clone();
#[tokio::test]
async fn nesting_and_seeing_the_right_uri() {
let one = nest("/foo", route("/bar", get(all_the_uris)));
let two = route("/foo", get(all_the_uris));
// let app = svc_one.or(svc_two);
let addr = run_in_background(one.or(two)).await;
// let addr = run_in_background(app).await;
// }
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/bar",
"request_uri": "/bar",
"original_uri": "/foo/bar",
})
);
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/foo",
"request_uri": "/foo",
"original_uri": "/foo",
})
);
}
#[tokio::test]
async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
let two = route("/foo", get(all_the_uris));
let addr = run_in_background(one.or(two)).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/baz",
"request_uri": "/baz",
"original_uri": "/foo/bar/baz",
})
);
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/foo",
"request_uri": "/foo",
"original_uri": "/foo",
})
);
}
#[tokio::test]
async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
let two = nest("/foo", route("/qux", get(all_the_uris)));
let three = route("/foo", get(all_the_uris));
let addr = run_in_background(one.or(two).or(three)).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/baz",
"request_uri": "/baz",
"original_uri": "/foo/bar/baz",
})
);
let res = client
.get(format!("http://{}/foo/qux", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/qux",
"request_uri": "/qux",
"original_uri": "/foo/qux",
})
);
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/foo",
"request_uri": "/foo",
"original_uri": "/foo",
})
);
}
#[tokio::test]
async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
let two = route("/foo/bar", get(all_the_uris));
let addr = run_in_background(one.or(two)).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/foo/bar/baz", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/baz",
"request_uri": "/baz",
"original_uri": "/foo/bar/baz",
})
);
let res = client
.get(format!("http://{}/foo/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(
res.json::<Value>().await.unwrap(),
json!({
"uri": "/foo/bar",
"request_uri": "/foo/bar",
"original_uri": "/foo/bar",
})
);
}