From e22045d42f91757f7224bdad35385e1119d5bfbe Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Wed, 18 Aug 2021 09:48:36 +0200
Subject: [PATCH] Change nested routes to see the URI with prefix stripped
 (#197)

---
 CHANGELOG.md                 |   3 +-
 Cargo.toml                   |   2 +
 src/extract/mod.rs           |   2 +-
 src/extract/rejection.rs     |  10 --
 src/extract/request_parts.rs |  20 ++--
 src/lib.rs                   |   6 ++
 src/routing/mod.rs           |  23 ++--
 src/routing/or.rs            |   8 ++
 src/tests/nest.rs            |  38 +++++--
 src/tests/or.rs              | 198 +++++++++++++++++++++++++++++++++--
 10 files changed, 255 insertions(+), 55 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b737428..ab4ea5f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153))
 - Add `RoutingDsl::or` for combining routes ([#108](https://github.com/tokio-rs/axum/pull/108))
 - Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
-- Add `NestedUri` for extracting request URI in nested services ([#161](https://github.com/tokio-rs/axum/pull/161))
+- Add `OriginalUri` for extracting original request URI in nested services ([#197](https://github.com/tokio-rs/axum/pull/197))
 - Implement `FromRequest` for `http::Extensions`
 - Implement SSE as an `IntoResponse` instead of a service ([#98](https://github.com/tokio-rs/axum/pull/98))
 - Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193))
@@ -35,7 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Ensure a `HandleError` service created from `ServiceExt::handle_error`
   _does not_ implement `RoutingDsl` as that could lead to confusing routing
   behavior ([#120](https://github.com/tokio-rs/axum/pull/120))
-- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156))
 - Implement `routing::MethodFilter` via [`bitflags`](https://crates.io/crates/bitflags)
 - Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
 - `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
diff --git a/Cargo.toml b/Cargo.toml
index 4a54b2ae..af2b4846 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -51,6 +51,8 @@ futures = "0.3"
 reqwest = { version = "0.11", features = ["json", "stream"] }
 serde = { version = "1.0", features = ["derive"] }
 tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] }
+uuid = { version = "0.8", features = ["serde", "v4"] }
+tokio-stream = "0.1"
 
 [dev-dependencies.tower]
 version = "0.4"
diff --git a/src/extract/mod.rs b/src/extract/mod.rs
index 0d5e247f..8e3adaea 100644
--- a/src/extract/mod.rs
+++ b/src/extract/mod.rs
@@ -316,7 +316,7 @@ pub use self::{
     path::Path,
     query::Query,
     raw_query::RawQuery,
-    request_parts::NestedUri,
+    request_parts::OriginalUri,
     request_parts::{Body, BodyStream},
 };
 #[doc(no_inline)]
diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs
index 89c05835..289577af 100644
--- a/src/extract/rejection.rs
+++ b/src/extract/rejection.rs
@@ -102,16 +102,6 @@ define_rejection! {
     pub struct InvalidFormContentType;
 }
 
-define_rejection! {
-    #[status = INTERNAL_SERVER_ERROR]
-    #[body = "`NestedUri` extractor used for route that isn't nested"]
-    /// Rejection type used if you try and extract [`NestedUri`] from a route that
-    /// isn't nested.
-    ///
-    /// [`NestedUri`]: crate::extract::NestedUri
-    pub struct NotNested;
-}
-
 /// Rejection type for [`Path`](super::Path) if the capture route
 /// param didn't have the expected type.
 #[derive(Debug)]
diff --git a/src/extract/request_parts.rs b/src/extract/request_parts.rs
index 4a6e2401..60efdf4e 100644
--- a/src/extract/request_parts.rs
+++ b/src/extract/request_parts.rs
@@ -82,10 +82,10 @@ where
     }
 }
 
-/// Extractor that gets the request URI for a nested service.
+/// Extractor that gets the original request URI regardless of nesting.
 ///
 /// This is necessary since [`Uri`](http::Uri), when used as an extractor, will
-/// always be the full URI.
+/// have the prefix stripped if used in a nested service.
 ///
 /// # Example
 ///
@@ -94,15 +94,15 @@ where
 ///     handler::get,
 ///     route,
 ///     routing::{nest, RoutingDsl},
-///     extract::NestedUri,
+///     extract::OriginalUri,
 ///     http::Uri
 /// };
 ///
 /// let api_routes = route(
 ///     "/users",
-///     get(|uri: Uri, NestedUri(nested_uri): NestedUri| async {
-///         // `uri` is `/api/users`
-///         // `nested_uri` is `/users`
+///     get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
+///         // `uri` is `/users`
+///         // `original_uri` is `/api/users`
 ///     }),
 /// );
 ///
@@ -112,19 +112,19 @@ where
 /// # };
 /// ```
 #[derive(Debug, Clone)]
-pub struct NestedUri(pub Uri);
+pub struct OriginalUri(pub Uri);
 
 #[async_trait]
-impl<B> FromRequest<B> for NestedUri
+impl<B> FromRequest<B> for OriginalUri
 where
     B: Send,
 {
-    type Rejection = NotNested;
+    type Rejection = Infallible;
 
     async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
         let uri = Extension::<Self>::from_request(req)
             .await
-            .map_err(|_| NotNested)?
+            .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
             .0;
         Ok(uri)
     }
diff --git a/src/lib.rs b/src/lib.rs
index 28f0c8fb..14f4292c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -304,6 +304,11 @@
 //! # };
 //! ```
 //!
+//! Note that nested routes will not see the orignal request URI but instead
+//! have the matched prefix stripped. This is necessary for services like static
+//! file serving to work. Use [`OriginalUri`] if you need the original request
+//! URI.
+//!
 //! # Extractors
 //!
 //! An extractor is a type that implements [`FromRequest`]. Extractors is how
@@ -747,6 +752,7 @@
 //! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
 //! [`RoutingDsl::or`]: crate::routing::RoutingDsl::or
 //! [`axum::Server`]: hyper::server::Server
+//! [`OriginalUri`]: crate::extract::OriginalUri
 
 #![warn(
     clippy::all,
diff --git a/src/routing/mod.rs b/src/routing/mod.rs
index e991adff..e24e7ecc 100644
--- a/src/routing/mod.rs
+++ b/src/routing/mod.rs
@@ -6,7 +6,7 @@ use crate::{
     buffer::MpscBuffer,
     extract::{
         connect_info::{Connected, IntoMakeServiceWithConnectInfo},
-        NestedUri,
+        OriginalUri,
     },
     service::{HandleError, HandleErrorFromRouter},
     util::ByteStr,
@@ -690,11 +690,7 @@ impl PathPattern {
     }
 
     fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
-        let path = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
-            nested_uri.0.path()
-        } else {
-            req.uri().path()
-        };
+        let path = req.uri().path();
 
         self.0.full_path_regex.captures(path).map(|captures| {
             let matched = captures.get(0).unwrap();
@@ -948,15 +944,14 @@ where
     }
 
     fn call(&mut self, mut req: Request<B>) -> Self::Future {
-        let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
-            let uri = if let Some(nested_uri) = req.extensions().get::<NestedUri>() {
-                &nested_uri.0
-            } else {
-                req.uri()
-            };
+        if req.extensions().get::<OriginalUri>().is_none() {
+            let original_uri = OriginalUri(req.uri().clone());
+            req.extensions_mut().insert(original_uri);
+        }
 
-            let without_prefix = strip_prefix(uri, prefix);
-            req.extensions_mut().insert(NestedUri(without_prefix));
+        let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
+            let without_prefix = strip_prefix(req.uri(), prefix);
+            *req.uri_mut() = without_prefix;
 
             insert_url_params(&mut req, captures);
             let fut = self.svc.clone().oneshot(req);
diff --git a/src/routing/or.rs b/src/routing/or.rs
index 4d5ee97e..836bf9db 100644
--- a/src/routing/or.rs
+++ b/src/routing/or.rs
@@ -47,6 +47,8 @@ where
     }
 
     fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
+        let original_uri = req.uri().clone();
+
         if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
             count.increment();
         } else {
@@ -58,6 +60,7 @@ where
                 f: self.first.clone().oneshot(req),
             },
             second: Some(self.second.clone()),
+            original_uri: Some(original_uri),
         }
     }
 }
@@ -72,6 +75,9 @@ pin_project! {
         #[pin]
         state: State<A, B, ReqBody>,
         second: Option<B>,
+        // Some services, namely `Nested`, mutates the request URI so we must
+        // restore it to its original state before calling `second`
+        original_uri: Option<http::Uri>,
     }
 }
 
@@ -115,6 +121,8 @@ where
                         return Poll::Ready(Ok(response));
                     };
 
+                    *req.uri_mut() = this.original_uri.take().unwrap();
+
                     let mut leaving_outermost_or = false;
                     if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
                         if depth == 1 {
diff --git a/src/tests/nest.rs b/src/tests/nest.rs
index bde77a42..3bf8587c 100644
--- a/src/tests/nest.rs
+++ b/src/tests/nest.rs
@@ -151,7 +151,7 @@ async fn nested_url_extractor() {
         .await
         .unwrap();
     assert_eq!(res.status(), StatusCode::OK);
-    assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
+    assert_eq!(res.text().await.unwrap(), "/baz");
 
     let res = client
         .get(format!("http://{}/foo/bar/qux", addr))
@@ -159,18 +159,18 @@ async fn nested_url_extractor() {
         .await
         .unwrap();
     assert_eq!(res.status(), StatusCode::OK);
-    assert_eq!(res.text().await.unwrap(), "/foo/bar/qux");
+    assert_eq!(res.text().await.unwrap(), "/qux");
 }
 
 #[tokio::test]
-async fn nested_url_nested_extractor() {
+async fn nested_url_original_extractor() {
     let app = nest(
         "/foo",
         nest(
             "/bar",
             route(
                 "/baz",
-                get(|uri: extract::NestedUri| async move { uri.0.to_string() }),
+                get(|uri: extract::OriginalUri| async move { uri.0.to_string() }),
             ),
         ),
     );
@@ -185,11 +185,11 @@ async fn nested_url_nested_extractor() {
         .await
         .unwrap();
     assert_eq!(res.status(), StatusCode::OK);
-    assert_eq!(res.text().await.unwrap(), "/baz");
+    assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
 }
 
 #[tokio::test]
-async fn nested_service_sees_original_uri() {
+async fn nested_service_sees_stripped_uri() {
     let app = nest(
         "/foo",
         nest(
@@ -214,5 +214,29 @@ async fn nested_service_sees_original_uri() {
         .await
         .unwrap();
     assert_eq!(res.status(), StatusCode::OK);
-    assert_eq!(res.text().await.unwrap(), "/foo/bar/baz");
+    assert_eq!(res.text().await.unwrap(), "/baz");
+}
+
+#[tokio::test]
+async fn nest_static_file_server() {
+    let app = nest(
+        "/static",
+        service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| {
+            Ok::<_, Infallible>((
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Unhandled internal error: {}", error),
+            ))
+        }),
+    );
+
+    let addr = run_in_background(app).await;
+
+    let client = reqwest::Client::new();
+
+    let res = client
+        .get(format!("http://{}/static/README.md", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
 }
diff --git a/src/tests/or.rs b/src/tests/or.rs
index b8e3766f..e7b44f22 100644
--- a/src/tests/or.rs
+++ b/src/tests/or.rs
@@ -1,5 +1,8 @@
+use serde_json::{json, Value};
 use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
 
+use crate::{extract::OriginalUri, response::IntoResponse, Json};
+
 use super::*;
 
 #[tokio::test]
@@ -302,17 +305,190 @@ async fn services() {
     assert_eq!(res.status(), StatusCode::OK);
 }
 
-// TODO(david): can we make this not compile?
-// #[tokio::test]
-// async fn foo() {
-//     let svc_one = service_fn(|_: Request<Body>| async {
-//         Ok::<_, hyper::Error>(Response::new(Body::empty()))
-//     })
-//     .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR));
+async fn all_the_uris(
+    uri: Uri,
+    OriginalUri(original_uri): OriginalUri,
+    req: Request<Body>,
+) -> impl IntoResponse {
+    Json(json!({
+        "uri": uri.to_string(),
+        "request_uri": req.uri().to_string(),
+        "original_uri": original_uri.to_string(),
+    }))
+}
 
-//     let svc_two = svc_one.clone();
+#[tokio::test]
+async fn nesting_and_seeing_the_right_uri() {
+    let one = nest("/foo", route("/bar", get(all_the_uris)));
+    let two = route("/foo", get(all_the_uris));
 
-//     let app = svc_one.or(svc_two);
+    let addr = run_in_background(one.or(two)).await;
 
-//     let addr = run_in_background(app).await;
-// }
+    let client = reqwest::Client::new();
+
+    let res = client
+        .get(format!("http://{}/foo/bar", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/bar",
+            "request_uri": "/bar",
+            "original_uri": "/foo/bar",
+        })
+    );
+
+    let res = client
+        .get(format!("http://{}/foo", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/foo",
+            "request_uri": "/foo",
+            "original_uri": "/foo",
+        })
+    );
+}
+
+#[tokio::test]
+async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() {
+    let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
+    let two = route("/foo", get(all_the_uris));
+
+    let addr = run_in_background(one.or(two)).await;
+
+    let client = reqwest::Client::new();
+
+    let res = client
+        .get(format!("http://{}/foo/bar/baz", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/baz",
+            "request_uri": "/baz",
+            "original_uri": "/foo/bar/baz",
+        })
+    );
+
+    let res = client
+        .get(format!("http://{}/foo", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/foo",
+            "request_uri": "/foo",
+            "original_uri": "/foo",
+        })
+    );
+}
+
+#[tokio::test]
+async fn nesting_and_seeing_the_right_uri_ors_with_nesting() {
+    let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
+    let two = nest("/foo", route("/qux", get(all_the_uris)));
+    let three = route("/foo", get(all_the_uris));
+
+    let addr = run_in_background(one.or(two).or(three)).await;
+
+    let client = reqwest::Client::new();
+
+    let res = client
+        .get(format!("http://{}/foo/bar/baz", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/baz",
+            "request_uri": "/baz",
+            "original_uri": "/foo/bar/baz",
+        })
+    );
+
+    let res = client
+        .get(format!("http://{}/foo/qux", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/qux",
+            "request_uri": "/qux",
+            "original_uri": "/foo/qux",
+        })
+    );
+
+    let res = client
+        .get(format!("http://{}/foo", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/foo",
+            "request_uri": "/foo",
+            "original_uri": "/foo",
+        })
+    );
+}
+
+#[tokio::test]
+async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
+    let one = nest("/foo", nest("/bar", route("/baz", get(all_the_uris))));
+    let two = route("/foo/bar", get(all_the_uris));
+
+    let addr = run_in_background(one.or(two)).await;
+
+    let client = reqwest::Client::new();
+
+    let res = client
+        .get(format!("http://{}/foo/bar/baz", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/baz",
+            "request_uri": "/baz",
+            "original_uri": "/foo/bar/baz",
+        })
+    );
+
+    let res = client
+        .get(format!("http://{}/foo/bar", addr))
+        .send()
+        .await
+        .unwrap();
+    assert_eq!(res.status(), StatusCode::OK);
+    assert_eq!(
+        res.json::<Value>().await.unwrap(),
+        json!({
+            "uri": "/foo/bar",
+            "request_uri": "/foo/bar",
+            "original_uri": "/foo/bar",
+        })
+    );
+}