From e22045d42f91757f7224bdad35385e1119d5bfbe Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Wed, 18 Aug 2021 09:48:36 +0200 Subject: [PATCH] Change nested routes to see the URI with prefix stripped (#197) --- CHANGELOG.md | 3 +- Cargo.toml | 2 + src/extract/mod.rs | 2 +- src/extract/rejection.rs | 10 -- src/extract/request_parts.rs | 20 ++-- src/lib.rs | 6 ++ src/routing/mod.rs | 23 ++-- src/routing/or.rs | 8 ++ src/tests/nest.rs | 38 +++++-- src/tests/or.rs | 198 +++++++++++++++++++++++++++++++++-- 10 files changed, 255 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b737428..ab4ea5f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153)) - Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108)) - Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160)) -- Add `NestedUri` for extracting request URI in nested services ([#161](https://github.com/tokio-rs/axum/pull/161)) +- Add `OriginalUri` for extracting original request URI in nested services ([#197](https://github.com/tokio-rs/axum/pull/197)) - Implement `FromRequest` for `http::Extensions` - Implement SSE as an `IntoResponse` instead of a service ([#98](https://github.com/tokio-rs/axum/pull/98)) - Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193)) @@ -35,7 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ensure a `HandleError` service created from `ServiceExt::handle_error` _does not_ implement `RoutingDsl` as that could lead to confusing routing behavior ([#120](https://github.com/tokio-rs/axum/pull/120)) -- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156)) - Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags) - Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead - `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108)) diff --git a/Cargo.toml b/Cargo.toml index 4a54b2ae..af2b4846 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ futures = "0.3" reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] } +uuid = { version = "0.8", features = ["serde", "v4"] } +tokio-stream = "0.1" [dev-dependencies.tower] version = "0.4" diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 0d5e247f..8e3adaea 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -316,7 +316,7 @@ pub use self::{ path::Path, query::Query, raw_query::RawQuery, - request_parts::NestedUri, + request_parts::OriginalUri, request_parts::{Body, BodyStream}, }; #[doc(no_inline)] diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 89c05835..289577af 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -102,16 +102,6 @@ define_rejection! { pub struct InvalidFormContentType; } -define_rejection! { - #[status = INTERNAL_SERVER_ERROR] - #[body = "`NestedUri` extractor used for route that isn't nested"] - /// Rejection type used if you try and extract [`NestedUri`] from a route that - /// isn't nested. - /// - /// [`NestedUri`]: crate::extract::NestedUri - pub struct NotNested; -} - /// Rejection type for [`Path`](super::Path) if the capture route /// param didn't have the expected type. #[derive(Debug)] diff --git a/src/extract/request_parts.rs b/src/extract/request_parts.rs index 4a6e2401..60efdf4e 100644 --- a/src/extract/request_parts.rs +++ b/src/extract/request_parts.rs @@ -82,10 +82,10 @@ where } } -/// Extractor that gets the request URI for a nested service. +/// Extractor that gets the original request URI regardless of nesting. /// /// This is necessary since [`Uri`](http::Uri), when used as an extractor, will -/// always be the full URI. +/// have the prefix stripped if used in a nested service. /// /// # Example /// @@ -94,15 +94,15 @@ where /// handler::get, /// route, /// routing::{nest, RoutingDsl}, -/// extract::NestedUri, +/// extract::OriginalUri, /// http::Uri /// }; /// /// let api_routes = route( /// "/users", -/// get(|uri: Uri, NestedUri(nested_uri): NestedUri| async { -/// // `uri` is `/api/users` -/// // `nested_uri` is `/users` +/// get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async { +/// // `uri` is `/users` +/// // `original_uri` is `/api/users` /// }), /// ); /// @@ -112,19 +112,19 @@ where /// # }; /// ``` #[derive(Debug, Clone)] -pub struct NestedUri(pub Uri); +pub struct OriginalUri(pub Uri); #[async_trait] -impl<B> FromRequest<B> for NestedUri +impl<B> FromRequest<B> for OriginalUri where B: Send, { - type Rejection = NotNested; + type Rejection = Infallible; async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { let uri = Extension::<Self>::from_request(req) .await - .map_err(|_| NotNested)? + .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) .0; Ok(uri) } diff --git a/src/lib.rs b/src/lib.rs index 28f0c8fb..14f4292c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -304,6 +304,11 @@ //! # }; //! ``` //! +//! Note that nested routes will not see the orignal request URI but instead +//! have the matched prefix stripped. This is necessary for services like static +//! file serving to work. Use [`OriginalUri`] if you need the original request +//! URI. +//! //! # Extractors //! //! An extractor is a type that implements [`FromRequest`]. Extractors is how @@ -747,6 +752,7 @@ //! [examples]: https://github.com/tokio-rs/axum/tree/main/examples //! [`RoutingDsl::or`]: crate::routing::RoutingDsl::or //! [`axum::Server`]: hyper::server::Server +//! [`OriginalUri`]: crate::extract::OriginalUri #![warn( clippy::all, diff --git a/src/routing/mod.rs b/src/routing/mod.rs index e991adff..e24e7ecc 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -6,7 +6,7 @@ use crate::{ buffer::MpscBuffer, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, - NestedUri, + OriginalUri, }, service::{HandleError, HandleErrorFromRouter}, util::ByteStr, @@ -690,11 +690,7 @@ impl PathPattern { } fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> { - let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() { - nested_uri.0.path() - } else { - req.uri().path() - }; + let path = req.uri().path(); self.0.full_path_regex.captures(path).map(|captures| { let matched = captures.get(0).unwrap(); @@ -948,15 +944,14 @@ where } fn call(&mut self, mut req: Request<B>) -> Self::Future { - let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) { - let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() { - &nested_uri.0 - } else { - req.uri() - }; + if req.extensions().get::<OriginalUri>().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } - let without_prefix = strip_prefix(uri, prefix); - req.extensions_mut().insert(NestedUri(without_prefix)); + let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) { + let without_prefix = strip_prefix(req.uri(), prefix); + *req.uri_mut() = without_prefix; insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); diff --git a/src/routing/or.rs b/src/routing/or.rs index 4d5ee97e..836bf9db 100644 --- a/src/routing/or.rs +++ b/src/routing/or.rs @@ -47,6 +47,8 @@ where } fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future { + let original_uri = req.uri().clone(); + if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() { count.increment(); } else { @@ -58,6 +60,7 @@ where f: self.first.clone().oneshot(req), }, second: Some(self.second.clone()), + original_uri: Some(original_uri), } } } @@ -72,6 +75,9 @@ pin_project! { #[pin] state: State<A, B, ReqBody>, second: Option<B>, + // Some services, namely `Nested`, mutates the request URI so we must + // restore it to its original state before calling `second` + original_uri: Option<http::Uri>, } } @@ -115,6 +121,8 @@ where return Poll::Ready(Ok(response)); }; + *req.uri_mut() = this.original_uri.take().unwrap(); + let mut leaving_outermost_or = false; if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() { if depth == 1 { diff --git a/src/tests/nest.rs b/src/tests/nest.rs index bde77a42..3bf8587c 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -151,7 +151,7 @@ async fn nested_url_extractor() { .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); + assert_eq!(res.text().await.unwrap(), "/baz"); let res = client .get(format!("http://{}/foo/bar/qux", addr)) @@ -159,18 +159,18 @@ async fn nested_url_extractor() { .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "/foo/bar/qux"); + assert_eq!(res.text().await.unwrap(), "/qux"); } #[tokio::test] -async fn nested_url_nested_extractor() { +async fn nested_url_original_extractor() { let app = nest( "/foo", nest( "/bar", route( "/baz", - get(|uri: extract::NestedUri| async move { uri.0.to_string() }), + get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), ), ), ); @@ -185,11 +185,11 @@ async fn nested_url_nested_extractor() { .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "/baz"); + assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); } #[tokio::test] -async fn nested_service_sees_original_uri() { +async fn nested_service_sees_stripped_uri() { let app = nest( "/foo", nest( @@ -214,5 +214,29 @@ async fn nested_service_sees_original_uri() { .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); + assert_eq!(res.text().await.unwrap(), "/baz"); +} + +#[tokio::test] +async fn nest_static_file_server() { + let app = nest( + "/static", + service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| { + Ok::<_, Infallible>(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + )) + }), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/static/README.md", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); } diff --git a/src/tests/or.rs b/src/tests/or.rs index b8e3766f..e7b44f22 100644 --- a/src/tests/or.rs +++ b/src/tests/or.rs @@ -1,5 +1,8 @@ +use serde_json::{json, Value}; use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; +use crate::{extract::OriginalUri, response::IntoResponse, Json}; + use super::*; #[tokio::test] @@ -302,17 +305,190 @@ async fn services() { assert_eq!(res.status(), StatusCode::OK); } -// TODO(david): can we make this not compile? -// #[tokio::test] -// async fn foo() { -// let svc_one = service_fn(|_: Request<Body>| async { -// Ok::<_, hyper::Error>(Response::new(Body::empty())) -// }) -// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR)); +async fn all_the_uris( + uri: Uri, + OriginalUri(original_uri): OriginalUri, + req: Request<Body>, +) -> impl IntoResponse { + Json(json!({ + "uri": uri.to_string(), + "request_uri": req.uri().to_string(), + "original_uri": original_uri.to_string(), + })) +} -// let svc_two = svc_one.clone(); +#[tokio::test] +async fn nesting_and_seeing_the_right_uri() { + let one = nest("/foo", route("/bar", get(all_the_uris))); + let two = route("/foo", get(all_the_uris)); -// let app = svc_one.or(svc_two); + let addr = run_in_background(one.or(two)).await; -// let addr = run_in_background(app).await; -// } + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/bar", + "request_uri": "/bar", + "original_uri": "/foo/bar", + }) + ); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/foo", + "request_uri": "/foo", + "original_uri": "/foo", + }) + ); +} + +#[tokio::test] +async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { + let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris)))); + let two = route("/foo", get(all_the_uris)); + + let addr = run_in_background(one.or(two)).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/baz", + "request_uri": "/baz", + "original_uri": "/foo/bar/baz", + }) + ); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/foo", + "request_uri": "/foo", + "original_uri": "/foo", + }) + ); +} + +#[tokio::test] +async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { + let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris)))); + let two = nest("/foo", route("/qux", get(all_the_uris))); + let three = route("/foo", get(all_the_uris)); + + let addr = run_in_background(one.or(two).or(three)).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/baz", + "request_uri": "/baz", + "original_uri": "/foo/bar/baz", + }) + ); + + let res = client + .get(format!("http://{}/foo/qux", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/qux", + "request_uri": "/qux", + "original_uri": "/foo/qux", + }) + ); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/foo", + "request_uri": "/foo", + "original_uri": "/foo", + }) + ); +} + +#[tokio::test] +async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { + let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris)))); + let two = route("/foo/bar", get(all_the_uris)); + + let addr = run_in_background(one.or(two)).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/baz", + "request_uri": "/baz", + "original_uri": "/foo/bar/baz", + }) + ); + + let res = client + .get(format!("http://{}/foo/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.json::<Value>().await.unwrap(), + json!({ + "uri": "/foo/bar", + "request_uri": "/foo/bar", + "original_uri": "/foo/bar", + }) + ); +}