From 0e3f9d09385bf22331361ee50352f95abeb07dea Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 8 Nov 2022 21:43:19 +0100 Subject: [PATCH] Don't allow extracting `MatchedPath` in middleware for nested routes (#1462) * Don't allow extracting `MatchedPath` for nested paths * misc clean up * Update docs * changelog * Apply suggestions from code review Co-authored-by: Jonas Platte * Add test for nested handler service * change to `debug_assert` * apply suggestions from review Co-authored-by: Jonas Platte --- axum/CHANGELOG.md | 3 + axum/src/extract/matched_path.rs | 306 +++++++++++++++++++++---------- axum/src/extract/mod.rs | 2 +- axum/src/routing/mod.rs | 4 +- axum/src/routing/service.rs | 57 ++---- 5 files changed, 224 insertions(+), 148 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1596f37b..a6d9e112 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -46,6 +46,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430]) - **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from [`axum-macros`] behind the `macros` feature ([#1352]) +- **breaking:** `MatchedPath` can now no longer be extracted in middleware for + nested routes ([#1462]) - **added:** Add `extract::RawForm` for accessing raw urlencoded query bytes or request body ([#1487]) - **breaking:** Rename `FormRejection::FailedToDeserializeQueryString` to `FormRejection::FailedToDeserializeForm` ([#1496]) @@ -64,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1418]: https://github.com/tokio-rs/axum/pull/1418 [#1420]: https://github.com/tokio-rs/axum/pull/1420 [#1421]: https://github.com/tokio-rs/axum/pull/1421 +[#1462]: https://github.com/tokio-rs/axum/pull/1462 [#1487]: https://github.com/tokio-rs/axum/pull/1487 [#1496]: https://github.com/tokio-rs/axum/pull/1496 diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index fa4d50b3..ef20cdb9 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,7 +1,8 @@ use super::{rejection::*, FromRequestParts}; +use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; use async_trait::async_trait; use http::request::Parts; -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; /// Access the path in the router that matches the request. /// @@ -24,7 +25,10 @@ use std::sync::Arc; /// # }; /// ``` /// +/// # Accessing `MatchedPath` via extensions +/// /// `MatchedPath` can also be accessed from middleware via request extensions. +/// /// This is useful for example with [`Trace`](tower_http::trace::Trace) to /// create a span that contains the matched path: /// @@ -49,10 +53,46 @@ use std::sync::Arc; /// tracing::info_span!("http-request", %path) /// }), /// ); -/// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; +/// # let _: Router = app; /// ``` +/// +/// # Matched path in nested routers +/// +/// Because of how [nesting] works `MatchedPath` isn't accessible in middleware on nested routes: +/// +/// ``` +/// use axum::{ +/// Router, +/// RequestExt, +/// routing::get, +/// extract::{MatchedPath, rejection::MatchedPathRejection}, +/// middleware::map_request, +/// http::Request, +/// body::Body, +/// }; +/// +/// async fn access_matched_path(mut request: Request) -> Request { +/// // if `/foo/bar` is called this will be `Err(_)` since that matches +/// // a nested route +/// let matched_path: Result = +/// request.extract_parts::().await; +/// +/// request +/// } +/// +/// // `MatchedPath` is always accessible on handlers added via `Router::route` +/// async fn handler(matched_path: MatchedPath) {} +/// +/// let app = Router::new() +/// .nest( +/// "/foo", +/// Router::new().route("/bar", get(handler)), +/// ) +/// .layer(map_request(access_matched_path)); +/// # let _: Router = app; +/// ``` +/// +/// [nesting]: crate::Router::nest #[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))] #[derive(Clone, Debug)] pub struct MatchedPath(pub(crate) Arc); @@ -82,119 +122,74 @@ where } } +#[derive(Clone, Debug)] +struct MatchedNestedPath(Arc); + +pub(crate) fn set_matched_path_for_request( + id: RouteId, + route_id_to_path: &HashMap>, + extensions: &mut http::Extensions, +) { + let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) { + matched_path + } else { + #[cfg(debug_assertions)] + panic!("should always have a matched path for a route id"); + }; + + let matched_path = append_nested_matched_path(matched_path, extensions); + + if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) { + extensions.insert(MatchedNestedPath(matched_path)); + debug_assert!(matches!(dbg!(extensions.remove::()), None)); + } else { + extensions.insert(MatchedPath(matched_path)); + extensions.remove::(); + } +} + +// a previous `MatchedPath` might exist if we're inside a nested Router +fn append_nested_matched_path(matched_path: &Arc, extensions: &http::Extensions) -> Arc { + if let Some(previous) = extensions + .get::() + .map(|matched_path| matched_path.as_str()) + .or_else(|| Some(&extensions.get::()?.0)) + { + let previous = previous + .strip_suffix(NEST_TAIL_PARAM_CAPTURE) + .unwrap_or(previous); + + let matched_path = format!("{previous}{matched_path}"); + matched_path.into() + } else { + Arc::clone(matched_path) + } +} + #[cfg(test)] mod tests { use super::*; use crate::{ - extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router, + handler::HandlerWithoutStateExt, middleware::map_request, routing::get, test_helpers::*, + Router, }; use http::{Request, StatusCode}; - use std::task::{Context, Poll}; - use tower::layer::layer_fn; - use tower_service::Service; - - #[derive(Clone)] - struct SetMatchedPathExtension(S); - - impl Service> for SetMatchedPathExtension - where - S: Service>, - { - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - let path = req - .extensions() - .get::() - .unwrap() - .as_str() - .to_owned(); - req.extensions_mut().insert(MatchedPathFromMiddleware(path)); - self.0.call(req) - } - } - - #[derive(Clone)] - struct MatchedPathFromMiddleware(String); #[tokio::test] - async fn access_matched_path() { - let api = Router::new().route( - "/users/:id", + async fn extracting_on_handler() { + let app = Router::new().route( + "/:a", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ); - async fn handler( - path: MatchedPath, - Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension< - MatchedPathFromMiddleware, - >, - ) -> String { - format!( - "extractor = {}, middleware = {}", - path.as_str(), - path_from_middleware - ) - } - - let app = Router::new() - .route( - "/:key", - get(|path: MatchedPath| async move { path.as_str().to_owned() }), - ) - .nest("/api", api) - .nest( - "/public", - Router::new() - .route("/assets/*path", get(handler)) - // have to set the middleware here since otherwise the - // matched path is just `/public/*` since we're nesting - // this router - .layer(layer_fn(SetMatchedPathExtension)), - ) - .nest_service("/foo", handler.into_service()) - .layer(layer_fn(SetMatchedPathExtension)); - let client = TestClient::new(app); - let res = client.get("/api/users/123").send().await; - assert_eq!(res.text().await, "/api/users/:id"); - - // the router nested at `/public` doesn't handle `/` - let res = client.get("/public").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - let res = client.get("/public/assets/css/style.css").send().await; - assert_eq!( - res.text().await, - "extractor = /public/assets/*path, middleware = /public/assets/*path" - ); - let res = client.get("/foo").send().await; - assert_eq!(res.text().await, "extractor = /foo, middleware = /foo"); - - let res = client.get("/foo/").send().await; - assert_eq!(res.text().await, "extractor = /foo/, middleware = /foo/"); - - let res = client.get("/foo/bar/baz").send().await; - assert_eq!( - res.text().await, - format!( - "extractor = /foo/*{}, middleware = /foo/*{}", - crate::routing::NEST_TAIL_PARAM, - crate::routing::NEST_TAIL_PARAM, - ), - ); + assert_eq!(res.text().await, "/:a"); } #[tokio::test] - async fn nested_opaque_routers_append_to_matched_path() { + async fn extracting_on_handler_in_nested_router() { let app = Router::new().nest( "/:a", Router::new().route( @@ -208,4 +203,115 @@ mod tests { let res = client.get("/foo/bar").send().await; assert_eq!(res.text().await, "/:a/:b"); } + + #[tokio::test] + async fn extracting_on_handler_in_deeply_nested_router() { + let app = Router::new().nest( + "/:a", + Router::new().nest( + "/:b", + Router::new().route( + "/:c", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ), + ), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + assert_eq!(res.text().await, "/:a/:b/:c"); + } + + #[tokio::test] + async fn cannot_extract_nested_matched_path_in_middleware() { + async fn extract_matched_path( + matched_path: Option, + req: Request, + ) -> Request { + assert!(matched_path.is_none()); + req + } + + let app = Router::new() + .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .layer(map_request(extract_matched_path)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn cannot_extract_nested_matched_path_in_middleware_via_extension() { + async fn assert_no_matched_path(req: Request) -> Request { + assert!(req.extensions().get::().is_none()); + req + } + + let app = Router::new() + .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .layer(map_request(assert_no_matched_path)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn can_extract_nested_matched_path_in_middleware_on_nested_router() { + async fn extract_matched_path(matched_path: MatchedPath, req: Request) -> Request { + assert_eq!(matched_path.as_str(), "/:a/:b"); + req + } + + let app = Router::new().nest( + "/:a", + Router::new() + .route("/:b", get(|| async move {})) + .layer(map_request(extract_matched_path)), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() { + async fn extract_matched_path(req: Request) -> Request { + let matched_path = req.extensions().get::().unwrap(); + assert_eq!(matched_path.as_str(), "/:a/:b"); + req + } + + let app = Router::new().nest( + "/:a", + Router::new() + .route("/:b", get(|| async move {})) + .layer(map_request(extract_matched_path)), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn extracting_on_nested_handler() { + async fn handler(path: Option) { + assert!(path.is_none()); + } + + let app = Router::new().nest_service("/:a", handler.into_service()); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } } diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 8bed75b3..9d5fb07c 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -49,7 +49,7 @@ pub use crate::Extension; pub use crate::form::Form; #[cfg(feature = "matched-path")] -mod matched_path; +pub(crate) mod matched_path; #[cfg(feature = "matched-path")] #[doc(inline)] diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 34741254..c5156a47 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -48,7 +48,7 @@ pub use self::method_routing::{ }; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -struct RouteId(u32); +pub(crate) struct RouteId(u32); impl RouteId { fn next() -> Self { @@ -110,7 +110,7 @@ where } pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; -const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; +pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; impl Router<(), B> where diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index bcf9ee3b..3bd9a629 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -1,23 +1,18 @@ +use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router}; +use crate::{ + body::{Body, HttpBody}, + response::Response, +}; +use http::Request; +use matchit::MatchError; use std::{ collections::HashMap, convert::Infallible, sync::Arc, task::{Context, Poll}, }; - -use http::Request; -use matchit::MatchError; use tower::Service; -use super::{ - future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router, - NEST_TAIL_PARAM_CAPTURE, -}; -use crate::{ - body::{Body, HttpBody}, - response::Response, -}; - /// A [`Router`] converted into a [`Service`]. #[derive(Debug)] pub struct RouterService { @@ -70,39 +65,11 @@ where let id = *match_.value; #[cfg(feature = "matched-path")] - { - fn set_matched_path( - id: RouteId, - route_id_to_path: &HashMap>, - extensions: &mut http::Extensions, - ) { - if let Some(matched_path) = route_id_to_path.get(&id) { - use crate::extract::MatchedPath; - - let matched_path = if let Some(previous) = extensions.get::() { - // a previous `MatchedPath` might exist if we're inside a nested Router - let previous = if let Some(previous) = - previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) - { - previous - } else { - previous.as_str() - }; - - let matched_path = format!("{}{}", previous, matched_path); - matched_path.into() - } else { - Arc::clone(matched_path) - }; - extensions.insert(MatchedPath(matched_path)); - } else { - #[cfg(debug_assertions)] - panic!("should always have a matched path for a route id"); - } - } - - set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut()); - } + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + req.extensions_mut(), + ); url_params::insert_url_params(req.extensions_mut(), match_.params);