diff --git a/CHANGELOG.md b/CHANGELOG.md index 66eb2197..dba2f7ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Make `FromRequest` default to being generic over `axum::body::Body` ([#146](https://github.com/tokio-rs/axum/pull/146)) - Implement `std::error::Error` for all rejections ([#153](https://github.com/tokio-rs/axum/pull/153)) +- Fix `Uri` extractor not being the full URI if using `nest` ([#156](https://github.com/tokio-rs/axum/pull/156)) ## Breaking changes diff --git a/src/extract/mod.rs b/src/extract/mod.rs index da37e282..33c63639 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -244,7 +244,7 @@ //! //! [`body::Body`]: crate::body::Body -use crate::response::IntoResponse; +use crate::{response::IntoResponse, routing::OriginalUri}; use async_trait::async_trait; use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; @@ -377,7 +377,7 @@ impl RequestParts { let ( http::request::Parts { method, - uri, + mut uri, version, headers, extensions, @@ -386,6 +386,10 @@ impl RequestParts { body, ) = req.into_parts(); + if let Some(original_uri) = extensions.get::() { + uri = original_uri.0.clone(); + }; + RequestParts { method, uri, diff --git a/src/routing.rs b/src/routing.rs index 71931c51..1988851a 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -907,6 +907,11 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { + if req.extensions().get::().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } + if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(req.uri(), prefix); *req.uri_mut() = without_prefix; @@ -921,6 +926,11 @@ where } } +/// `Nested` changes the incoming requests URI. This will be saved as an +/// extension so extractors can still access the original URI. +#[derive(Clone)] +pub(crate) struct OriginalUri(pub(crate) Uri); + fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { diff --git a/src/tests/nest.rs b/src/tests/nest.rs index f518eb4e..cee5f1d5 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -126,3 +126,35 @@ async fn nesting_at_root() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/foo/bar"); } + +#[tokio::test] +async fn nested_url_extractor() { + let app = nest( + "/foo", + nest( + "/bar", + route("/baz", get(|uri: Uri| async move { uri.to_string() })).route( + "/qux", + get(|req: Request| async move { req.uri().to_string() }), + ), + ), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); + + let res = client + .get(format!("http://{}/foo/bar/qux", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "/foo/bar/qux"); +}