diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md
index 1596f37b..a6d9e112 100644
--- a/axum/CHANGELOG.md
+++ b/axum/CHANGELOG.md
@@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430])
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])
+- **breaking:** `MatchedPath` can now no longer be extracted in middleware for
+ nested routes ([#1462])
- **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487])
- **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to
`FormRejection::FailedToDeserializeForm` ([#1496])
@@ -64,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
+[#1462]: https://github.com/tokio-rs/axum/pull/1462
[#1487]: https://github.com/tokio-rs/axum/pull/1487
[#1496]: https://github.com/tokio-rs/axum/pull/1496
diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs
index fa4d50b3..ef20cdb9 100644
--- a/axum/src/extract/matched_path.rs
+++ b/axum/src/extract/matched_path.rs
@@ -1,7 +1,8 @@
use super::{rejection::*, FromRequestParts};
+use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use async_trait::async_trait;
use http::request::Parts;
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
/// Access the path in the router that matches the request.
///
@@ -24,7 +25,10 @@ use std::sync::Arc;
/// # };
/// ```
///
+/// # Accessing `MatchedPath` via extensions
+///
/// `MatchedPath` can also be accessed from middleware via request extensions.
+///
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
@@ -49,10 +53,46 @@ use std::sync::Arc;
/// tracing::info_span!("http-request", %path)
/// }),
/// );
-/// # async {
-/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
-/// # };
+/// # let _: Router = app;
/// ```
+///
+/// # Matched path in nested routers
+///
+/// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes:
+///
+/// ```
+/// use axum::{
+/// Router,
+/// RequestExt,
+/// routing::get,
+/// extract::{MatchedPath, rejection::MatchedPathRejection},
+/// middleware::map_request,
+/// http::Request,
+/// body::Body,
+/// };
+///
+/// async fn access_matched_path(mut request: Request
) -> Request {
+/// // if `/foo/bar` is called this will be `Err(_)` since that matches
+/// // a nested route
+/// let matched_path: Result =
+/// request.extract_parts::().await;
+///
+/// request
+/// }
+///
+/// // `MatchedPath` is always accessible on handlers added via `Router::route`
+/// async fn handler(matched_path: MatchedPath) {}
+///
+/// let app = Router::new()
+/// .nest(
+/// "/foo",
+/// Router::new().route("/bar", get(handler)),
+/// )
+/// .layer(map_request(access_matched_path));
+/// # let _: Router = app;
+/// ```
+///
+/// [nesting]: crate::Router::nest
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))]
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc);
@@ -82,119 +122,74 @@ where
}
}
+#[derive(Clone, Debug)]
+struct MatchedNestedPath(Arc);
+
+pub(crate) fn set_matched_path_for_request(
+ id: RouteId,
+ route_id_to_path: &HashMap>,
+ extensions: &mut http::Extensions,
+) {
+ let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) {
+ matched_path
+ } else {
+ #[cfg(debug_assertions)]
+ panic!("should always have a matched path for a route id");
+ };
+
+ let matched_path = append_nested_matched_path(matched_path, extensions);
+
+ if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) {
+ extensions.insert(MatchedNestedPath(matched_path));
+ debug_assert!(matches!(dbg!(extensions.remove::()), None));
+ } else {
+ extensions.insert(MatchedPath(matched_path));
+ extensions.remove::();
+ }
+}
+
+// a previous `MatchedPath` might exist if we're inside a nested Router
+fn append_nested_matched_path(matched_path: &Arc, extensions: &http::Extensions) -> Arc {
+ if let Some(previous) = extensions
+ .get::()
+ .map(|matched_path| matched_path.as_str())
+ .or_else(|| Some(&extensions.get::()?.0))
+ {
+ let previous = previous
+ .strip_suffix(NEST_TAIL_PARAM_CAPTURE)
+ .unwrap_or(previous);
+
+ let matched_path = format!("{previous}{matched_path}");
+ matched_path.into()
+ } else {
+ Arc::clone(matched_path)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::{
- extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
+ handler::HandlerWithoutStateExt, middleware::map_request, routing::get, test_helpers::*,
+ Router,
};
use http::{Request, StatusCode};
- use std::task::{Context, Poll};
- use tower::layer::layer_fn;
- use tower_service::Service;
-
- #[derive(Clone)]
- struct SetMatchedPathExtension(S);
-
- impl Service> for SetMatchedPathExtension
- where
- S: Service>,
- {
- type Response = S::Response;
- type Error = S::Error;
- type Future = S::Future;
-
- fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
- self.0.poll_ready(cx)
- }
-
- fn call(&mut self, mut req: Request) -> Self::Future {
- let path = req
- .extensions()
- .get::()
- .unwrap()
- .as_str()
- .to_owned();
- req.extensions_mut().insert(MatchedPathFromMiddleware(path));
- self.0.call(req)
- }
- }
-
- #[derive(Clone)]
- struct MatchedPathFromMiddleware(String);
#[tokio::test]
- async fn access_matched_path() {
- let api = Router::new().route(
- "/users/:id",
+ async fn extracting_on_handler() {
+ let app = Router::new().route(
+ "/:a",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);
- async fn handler(
- path: MatchedPath,
- Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
- MatchedPathFromMiddleware,
- >,
- ) -> String {
- format!(
- "extractor = {}, middleware = {}",
- path.as_str(),
- path_from_middleware
- )
- }
-
- let app = Router::new()
- .route(
- "/:key",
- get(|path: MatchedPath| async move { path.as_str().to_owned() }),
- )
- .nest("/api", api)
- .nest(
- "/public",
- Router::new()
- .route("/assets/*path", get(handler))
- // have to set the middleware here since otherwise the
- // matched path is just `/public/*` since we're nesting
- // this router
- .layer(layer_fn(SetMatchedPathExtension)),
- )
- .nest_service("/foo", handler.into_service())
- .layer(layer_fn(SetMatchedPathExtension));
-
let client = TestClient::new(app);
- let res = client.get("/api/users/123").send().await;
- assert_eq!(res.text().await, "/api/users/:id");
-
- // the router nested at `/public` doesn't handle `/`
- let res = client.get("/public").send().await;
- assert_eq!(res.status(), StatusCode::NOT_FOUND);
-
- let res = client.get("/public/assets/css/style.css").send().await;
- assert_eq!(
- res.text().await,
- "extractor = /public/assets/*path, middleware = /public/assets/*path"
- );
-
let res = client.get("/foo").send().await;
- assert_eq!(res.text().await, "extractor = /foo, middleware = /foo");
-
- let res = client.get("/foo/").send().await;
- assert_eq!(res.text().await, "extractor = /foo/, middleware = /foo/");
-
- let res = client.get("/foo/bar/baz").send().await;
- assert_eq!(
- res.text().await,
- format!(
- "extractor = /foo/*{}, middleware = /foo/*{}",
- crate::routing::NEST_TAIL_PARAM,
- crate::routing::NEST_TAIL_PARAM,
- ),
- );
+ assert_eq!(res.text().await, "/:a");
}
#[tokio::test]
- async fn nested_opaque_routers_append_to_matched_path() {
+ async fn extracting_on_handler_in_nested_router() {
let app = Router::new().nest(
"/:a",
Router::new().route(
@@ -208,4 +203,115 @@ mod tests {
let res = client.get("/foo/bar").send().await;
assert_eq!(res.text().await, "/:a/:b");
}
+
+ #[tokio::test]
+ async fn extracting_on_handler_in_deeply_nested_router() {
+ let app = Router::new().nest(
+ "/:a",
+ Router::new().nest(
+ "/:b",
+ Router::new().route(
+ "/:c",
+ get(|path: MatchedPath| async move { path.as_str().to_owned() }),
+ ),
+ ),
+ );
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar/baz").send().await;
+ assert_eq!(res.text().await, "/:a/:b/:c");
+ }
+
+ #[tokio::test]
+ async fn cannot_extract_nested_matched_path_in_middleware() {
+ async fn extract_matched_path(
+ matched_path: Option,
+ req: Request,
+ ) -> Request {
+ assert!(matched_path.is_none());
+ req
+ }
+
+ let app = Router::new()
+ .nest("/:a", Router::new().route("/:b", get(|| async move {})))
+ .layer(map_request(extract_matched_path));
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar").send().await;
+ assert_eq!(res.status(), StatusCode::OK);
+ }
+
+ #[tokio::test]
+ async fn cannot_extract_nested_matched_path_in_middleware_via_extension() {
+ async fn assert_no_matched_path(req: Request) -> Request {
+ assert!(req.extensions().get::().is_none());
+ req
+ }
+
+ let app = Router::new()
+ .nest("/:a", Router::new().route("/:b", get(|| async move {})))
+ .layer(map_request(assert_no_matched_path));
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar").send().await;
+ assert_eq!(res.status(), StatusCode::OK);
+ }
+
+ #[tokio::test]
+ async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
+ async fn extract_matched_path(matched_path: MatchedPath, req: Request) -> Request {
+ assert_eq!(matched_path.as_str(), "/:a/:b");
+ req
+ }
+
+ let app = Router::new().nest(
+ "/:a",
+ Router::new()
+ .route("/:b", get(|| async move {}))
+ .layer(map_request(extract_matched_path)),
+ );
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar").send().await;
+ assert_eq!(res.status(), StatusCode::OK);
+ }
+
+ #[tokio::test]
+ async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() {
+ async fn extract_matched_path(req: Request) -> Request {
+ let matched_path = req.extensions().get::().unwrap();
+ assert_eq!(matched_path.as_str(), "/:a/:b");
+ req
+ }
+
+ let app = Router::new().nest(
+ "/:a",
+ Router::new()
+ .route("/:b", get(|| async move {}))
+ .layer(map_request(extract_matched_path)),
+ );
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar").send().await;
+ assert_eq!(res.status(), StatusCode::OK);
+ }
+
+ #[tokio::test]
+ async fn extracting_on_nested_handler() {
+ async fn handler(path: Option) {
+ assert!(path.is_none());
+ }
+
+ let app = Router::new().nest_service("/:a", handler.into_service());
+
+ let client = TestClient::new(app);
+
+ let res = client.get("/foo/bar").send().await;
+ assert_eq!(res.status(), StatusCode::OK);
+ }
}
diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs
index 8bed75b3..9d5fb07c 100644
--- a/axum/src/extract/mod.rs
+++ b/axum/src/extract/mod.rs
@@ -49,7 +49,7 @@ pub use crate::Extension;
pub use crate::form::Form;
#[cfg(feature = "matched-path")]
-mod matched_path;
+pub(crate) mod matched_path;
#[cfg(feature = "matched-path")]
#[doc(inline)]
diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs
index 34741254..c5156a47 100644
--- a/axum/src/routing/mod.rs
+++ b/axum/src/routing/mod.rs
@@ -48,7 +48,7 @@ pub use self::method_routing::{
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
-struct RouteId(u32);
+pub(crate) struct RouteId(u32);
impl RouteId {
fn next() -> Self {
@@ -110,7 +110,7 @@ where
}
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
-const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
+pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
impl Router<(), B>
where
diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs
index bcf9ee3b..3bd9a629 100644
--- a/axum/src/routing/service.rs
+++ b/axum/src/routing/service.rs
@@ -1,23 +1,18 @@
+use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router};
+use crate::{
+ body::{Body, HttpBody},
+ response::Response,
+};
+use http::Request;
+use matchit::MatchError;
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
task::{Context, Poll},
};
-
-use http::Request;
-use matchit::MatchError;
use tower::Service;
-use super::{
- future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router,
- NEST_TAIL_PARAM_CAPTURE,
-};
-use crate::{
- body::{Body, HttpBody},
- response::Response,
-};
-
/// A [`Router`] converted into a [`Service`].
#[derive(Debug)]
pub struct RouterService {
@@ -70,39 +65,11 @@ where
let id = *match_.value;
#[cfg(feature = "matched-path")]
- {
- fn set_matched_path(
- id: RouteId,
- route_id_to_path: &HashMap>,
- extensions: &mut http::Extensions,
- ) {
- if let Some(matched_path) = route_id_to_path.get(&id) {
- use crate::extract::MatchedPath;
-
- let matched_path = if let Some(previous) = extensions.get::() {
- // a previous `MatchedPath` might exist if we're inside a nested Router
- let previous = if let Some(previous) =
- previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
- {
- previous
- } else {
- previous.as_str()
- };
-
- let matched_path = format!("{}{}", previous, matched_path);
- matched_path.into()
- } else {
- Arc::clone(matched_path)
- };
- extensions.insert(MatchedPath(matched_path));
- } else {
- #[cfg(debug_assertions)]
- panic!("should always have a matched path for a route id");
- }
- }
-
- set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut());
- }
+ crate::extract::matched_path::set_matched_path_for_request(
+ id,
+ &self.node.route_id_to_path,
+ req.extensions_mut(),
+ );
url_params::insert_url_params(req.extensions_mut(), match_.params);