mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Change nested routes to see the URI with prefix stripped (#197)
This commit is contained in:
parent
cb637f1124
commit
e22045d42f
10 changed files with 255 additions and 55 deletions
|
@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
|
||||
- Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||
- Add `NestedUri` for extracting request URI in nested services ([#161](https://github.com/tokio-rs/axum/pull/161))
|
||||
- Add `OriginalUri` for extracting original request URI in nested services ([#197](https://github.com/tokio-rs/axum/pull/197))
|
||||
- Implement `FromRequest` for `http::Extensions`
|
||||
- Implement SSE as an `IntoResponse` instead of a service ([#98](https://github.com/tokio-rs/axum/pull/98))
|
||||
- Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193))
|
||||
|
@ -35,7 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Ensure a `HandleError` service created from `ServiceExt::handle_error`
|
||||
_does not_ implement `RoutingDsl` as that could lead to confusing routing
|
||||
behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
|
||||
- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
|
||||
- Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
|
||||
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
||||
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||
|
|
|
@ -51,6 +51,8 @@ futures = "0.3"
|
|||
reqwest = { version = "0.11", features = ["json", "stream"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] }
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||
tokio-stream = "0.1"
|
||||
|
||||
[dev-dependencies.tower]
|
||||
version = "0.4"
|
||||
|
|
|
@ -316,7 +316,7 @@ pub use self::{
|
|||
path::Path,
|
||||
query::Query,
|
||||
raw_query::RawQuery,
|
||||
request_parts::NestedUri,
|
||||
request_parts::OriginalUri,
|
||||
request_parts::{Body, BodyStream},
|
||||
};
|
||||
#[doc(no_inline)]
|
||||
|
|
|
@ -102,16 +102,6 @@ define_rejection! {
|
|||
pub struct InvalidFormContentType;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
#[body = "`NestedUri` extractor used for route that isn't nested"]
|
||||
/// Rejection type used if you try and extract [`NestedUri`] from a route that
|
||||
/// isn't nested.
|
||||
///
|
||||
/// [`NestedUri`]: crate::extract::NestedUri
|
||||
pub struct NotNested;
|
||||
}
|
||||
|
||||
/// Rejection type for [`Path`](super::Path) if the capture route
|
||||
/// param didn't have the expected type.
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -82,10 +82,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extractor that gets the request URI for a nested service.
|
||||
/// Extractor that gets the original request URI regardless of nesting.
|
||||
///
|
||||
/// This is necessary since [`Uri`](http::Uri), when used as an extractor, will
|
||||
/// always be the full URI.
|
||||
/// have the prefix stripped if used in a nested service.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -94,15 +94,15 @@ where
|
|||
/// handler::get,
|
||||
/// route,
|
||||
/// routing::{nest, RoutingDsl},
|
||||
/// extract::NestedUri,
|
||||
/// extract::OriginalUri,
|
||||
/// http::Uri
|
||||
/// };
|
||||
///
|
||||
/// let api_routes = route(
|
||||
/// "/users",
|
||||
/// get(|uri: Uri, NestedUri(nested_uri): NestedUri| async {
|
||||
/// // `uri` is `/api/users`
|
||||
/// // `nested_uri` is `/users`
|
||||
/// get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
|
||||
/// // `uri` is `/users`
|
||||
/// // `original_uri` is `/api/users`
|
||||
/// }),
|
||||
/// );
|
||||
///
|
||||
|
@ -112,19 +112,19 @@ where
|
|||
/// # };
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NestedUri(pub Uri);
|
||||
pub struct OriginalUri(pub Uri);
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for NestedUri
|
||||
impl<B> FromRequest<B> for OriginalUri
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = NotNested;
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let uri = Extension::<Self>::from_request(req)
|
||||
.await
|
||||
.map_err(|_| NotNested)?
|
||||
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
|
||||
.0;
|
||||
Ok(uri)
|
||||
}
|
||||
|
|
|
@ -304,6 +304,11 @@
|
|||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! Note that nested routes will not see the orignal request URI but instead
|
||||
//! have the matched prefix stripped. This is necessary for services like static
|
||||
//! file serving to work. Use [`OriginalUri`] if you need the original request
|
||||
//! URI.
|
||||
//!
|
||||
//! # Extractors
|
||||
//!
|
||||
//! An extractor is a type that implements [`FromRequest`]. Extractors is how
|
||||
|
@ -747,6 +752,7 @@
|
|||
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
|
||||
//! [`RoutingDsl::or`]: crate::routing::RoutingDsl::or
|
||||
//! [`axum::Server`]: hyper::server::Server
|
||||
//! [`OriginalUri`]: crate::extract::OriginalUri
|
||||
|
||||
#![warn(
|
||||
clippy::all,
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{
|
|||
buffer::MpscBuffer,
|
||||
extract::{
|
||||
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||
NestedUri,
|
||||
OriginalUri,
|
||||
},
|
||||
service::{HandleError, HandleErrorFromRouter},
|
||||
util::ByteStr,
|
||||
|
@ -690,11 +690,7 @@ impl PathPattern {
|
|||
}
|
||||
|
||||
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
|
||||
let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||
nested_uri.0.path()
|
||||
} else {
|
||||
req.uri().path()
|
||||
};
|
||||
let path = req.uri().path();
|
||||
|
||||
self.0.full_path_regex.captures(path).map(|captures| {
|
||||
let matched = captures.get(0).unwrap();
|
||||
|
@ -948,15 +944,14 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
|
||||
let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
|
||||
&nested_uri.0
|
||||
} else {
|
||||
req.uri()
|
||||
};
|
||||
if req.extensions().get::<OriginalUri>().is_none() {
|
||||
let original_uri = OriginalUri(req.uri().clone());
|
||||
req.extensions_mut().insert(original_uri);
|
||||
}
|
||||
|
||||
let without_prefix = strip_prefix(uri, prefix);
|
||||
req.extensions_mut().insert(NestedUri(without_prefix));
|
||||
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
|
||||
let without_prefix = strip_prefix(req.uri(), prefix);
|
||||
*req.uri_mut() = without_prefix;
|
||||
|
||||
insert_url_params(&mut req, captures);
|
||||
let fut = self.svc.clone().oneshot(req);
|
||||
|
|
|
@ -47,6 +47,8 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
let original_uri = req.uri().clone();
|
||||
|
||||
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
count.increment();
|
||||
} else {
|
||||
|
@ -58,6 +60,7 @@ where
|
|||
f: self.first.clone().oneshot(req),
|
||||
},
|
||||
second: Some(self.second.clone()),
|
||||
original_uri: Some(original_uri),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -72,6 +75,9 @@ pin_project! {
|
|||
#[pin]
|
||||
state: State<A, B, ReqBody>,
|
||||
second: Option<B>,
|
||||
// Some services, namely `Nested`, mutates the request URI so we must
|
||||
// restore it to its original state before calling `second`
|
||||
original_uri: Option<http::Uri>,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,6 +121,8 @@ where
|
|||
return Poll::Ready(Ok(response));
|
||||
};
|
||||
|
||||
*req.uri_mut() = this.original_uri.take().unwrap();
|
||||
|
||||
let mut leaving_outermost_or = false;
|
||||
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
if depth == 1 {
|
||||
|
|
|
@ -151,7 +151,7 @@ async fn nested_url_extractor() {
|
|||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||
assert_eq!(res.text().await.unwrap(), "/baz");
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar/qux", addr))
|
||||
|
@ -159,18 +159,18 @@ async fn nested_url_extractor() {
|
|||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
|
||||
assert_eq!(res.text().await.unwrap(), "/qux");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nested_url_nested_extractor() {
|
||||
async fn nested_url_original_extractor() {
|
||||
let app = nest(
|
||||
"/foo",
|
||||
nest(
|
||||
"/bar",
|
||||
route(
|
||||
"/baz",
|
||||
get(|uri: extract::NestedUri| async move { uri.0.to_string() }),
|
||||
get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
@ -185,11 +185,11 @@ async fn nested_url_nested_extractor() {
|
|||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/baz");
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nested_service_sees_original_uri() {
|
||||
async fn nested_service_sees_stripped_uri() {
|
||||
let app = nest(
|
||||
"/foo",
|
||||
nest(
|
||||
|
@ -214,5 +214,29 @@ async fn nested_service_sees_original_uri() {
|
|||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
|
||||
assert_eq!(res.text().await.unwrap(), "/baz");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nest_static_file_server() {
|
||||
let app = nest(
|
||||
"/static",
|
||||
service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| {
|
||||
Ok::<_, Infallible>((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
))
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/static/README.md", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
|
198
src/tests/or.rs
198
src/tests/or.rs
|
@ -1,5 +1,8 @@
|
|||
use serde_json::{json, Value};
|
||||
use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
|
||||
|
||||
use crate::{extract::OriginalUri, response::IntoResponse, Json};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -302,17 +305,190 @@ async fn services() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
// TODO(david): can we make this not compile?
|
||||
// #[tokio::test]
|
||||
// async fn foo() {
|
||||
// let svc_one = service_fn(|_: Request<Body>| async {
|
||||
// Ok::<_, hyper::Error>(Response::new(Body::empty()))
|
||||
// })
|
||||
// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR));
|
||||
async fn all_the_uris(
|
||||
uri: Uri,
|
||||
OriginalUri(original_uri): OriginalUri,
|
||||
req: Request<Body>,
|
||||
) -> impl IntoResponse {
|
||||
Json(json!({
|
||||
"uri": uri.to_string(),
|
||||
"request_uri": req.uri().to_string(),
|
||||
"original_uri": original_uri.to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
// let svc_two = svc_one.clone();
|
||||
#[tokio::test]
|
||||
async fn nesting_and_seeing_the_right_uri() {
|
||||
let one = nest("/foo", route("/bar", get(all_the_uris)));
|
||||
let two = route("/foo", get(all_the_uris));
|
||||
|
||||
// let app = svc_one.or(svc_two);
|
||||
let addr = run_in_background(one.or(two)).await;
|
||||
|
||||
// let addr = run_in_background(app).await;
|
||||
// }
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/bar",
|
||||
"request_uri": "/bar",
|
||||
"original_uri": "/foo/bar",
|
||||
})
|
||||
);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/foo",
|
||||
"request_uri": "/foo",
|
||||
"original_uri": "/foo",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
|
||||
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
|
||||
let two = route("/foo", get(all_the_uris));
|
||||
|
||||
let addr = run_in_background(one.or(two)).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar/baz", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/baz",
|
||||
"request_uri": "/baz",
|
||||
"original_uri": "/foo/bar/baz",
|
||||
})
|
||||
);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/foo",
|
||||
"request_uri": "/foo",
|
||||
"original_uri": "/foo",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
|
||||
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
|
||||
let two = nest("/foo", route("/qux", get(all_the_uris)));
|
||||
let three = route("/foo", get(all_the_uris));
|
||||
|
||||
let addr = run_in_background(one.or(two).or(three)).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar/baz", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/baz",
|
||||
"request_uri": "/baz",
|
||||
"original_uri": "/foo/bar/baz",
|
||||
})
|
||||
);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/qux", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/qux",
|
||||
"request_uri": "/qux",
|
||||
"original_uri": "/foo/qux",
|
||||
})
|
||||
);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/foo",
|
||||
"request_uri": "/foo",
|
||||
"original_uri": "/foo",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
|
||||
let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
|
||||
let two = route("/foo/bar", get(all_the_uris));
|
||||
|
||||
let addr = run_in_background(one.or(two)).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar/baz", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/baz",
|
||||
"request_uri": "/baz",
|
||||
"original_uri": "/foo/bar/baz",
|
||||
})
|
||||
);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(
|
||||
res.json::<Value>().await.unwrap(),
|
||||
json!({
|
||||
"uri": "/foo/bar",
|
||||
"request_uri": "/foo/bar",
|
||||
"original_uri": "/foo/bar",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue