From 2e8a7e51a105ab2fda2e7ca7cb09168cc8a26bf1 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 18 Nov 2022 11:00:52 +0100 Subject: [PATCH] Fix `Handler::with_state` not working if request body was changed via layer (#1536) Previously ```rust handler.layer(RequestBodyLimitLayer::new(...)).with_state(...) ``` didn't work because we required the same request body all the way through. --- .../fail/generic_without_via.stderr | 2 +- .../fail/generic_without_via_rejection.stderr | 2 +- ...rride_rejection_on_enum_without_via.stderr | 4 +- axum/Cargo.toml | 2 +- axum/src/handler/mod.rs | 57 ++++++++++++++----- examples/key-value-store/Cargo.toml | 1 + examples/key-value-store/src/main.rs | 12 ++-- 7 files changed, 56 insertions(+), 24 deletions(-) diff --git a/axum-macros/tests/from_request/fail/generic_without_via.stderr b/axum-macros/tests/from_request/fail/generic_without_via.stderr index 40b4e435..a943a54f 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via.stderr @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {fo | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr index a48c43e4..8d53346e 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {fo | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr index 8517f367..6b2b8bbe 100644 --- a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr +++ b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr @@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {hand | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | @@ -28,7 +28,7 @@ error[E0277]: the trait bound `fn(Result) -> impl Futu | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `MethodRouter::::post` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 641d7742..3674e649 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -48,7 +48,7 @@ serde = "1.0" sync_wrapper = "0.1.1" tower = { version = "0.4.13", default-features = false, features = ["util"] } tower-http = { version = "0.3.0", features = ["util", "map-response-body"] } -tower-layer = "0.3" +tower-layer = "0.3.2" tower-service = "0.3" # optional dependencies diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 986b1342..d7427dc1 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -139,9 +139,10 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn layer(self, layer: L) -> Layered + fn layer(self, layer: L) -> Layered where L: Layer> + Clone, + L::Service: Service>, { Layered { layer, @@ -220,13 +221,13 @@ all_the_tuples!(impl_handler); /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// /// Created with [`Handler::layer`]. See that method for more details. -pub struct Layered { +pub struct Layered { layer: L, handler: H, - _marker: PhantomData (T, S, B)>, + _marker: PhantomData (T, S, B, B2)>, } -impl fmt::Debug for Layered +impl fmt::Debug for Layered where L: fmt::Debug, { @@ -237,7 +238,7 @@ where } } -impl Clone for Layered +impl Clone for Layered where L: Clone, H: Clone, @@ -251,20 +252,21 @@ where } } -impl Handler for Layered +impl Handler for Layered where L: Layer> + Clone + Send + 'static, H: Handler, - L::Service: Service, Error = Infallible> + Clone + Send + 'static, - >>::Response: IntoResponse, - >>::Future: Send, + L::Service: Service, Error = Infallible> + Clone + Send + 'static, + >>::Response: IntoResponse, + >>::Future: Send, T: 'static, S: 'static, B: Send + 'static, + B2: Send + 'static, { - type Future = future::LayeredFuture; + type Future = future::LayeredFuture; - fn call(self, req: Request, state: S) -> Self::Future { + fn call(self, req: Request, state: S) -> Self::Future { use futures_util::future::{FutureExt, Map}; let svc = self.handler.with_state(state); @@ -274,8 +276,8 @@ where _, fn( Result< - >>::Response, - >>::Error, + >>::Response, + >>::Error, >, ) -> _, > = svc.oneshot(req).map(|result| match result { @@ -338,8 +340,14 @@ where #[cfg(test)] mod tests { use super::*; - use crate::test_helpers::*; + use crate::{body, extract::State, test_helpers::*}; use http::StatusCode; + use std::time::Duration; + use tower_http::{ + compression::CompressionLayer, limit::RequestBodyLimitLayer, + map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer, + timeout::TimeoutLayer, + }; #[tokio::test] async fn handler_into_service() { @@ -353,4 +361,25 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "you said: hi there!"); } + + #[tokio::test] + async fn with_layer_that_changes_request_body_and_state() { + async fn handle(State(state): State<&'static str>) -> &'static str { + state + } + + let svc = handle + .layer(( + RequestBodyLimitLayer::new(1024), + TimeoutLayer::new(Duration::from_secs(10)), + MapResponseBodyLayer::new(body::boxed), + CompressionLayer::new(), + )) + .layer(MapRequestBodyLayer::new(body::boxed)) + .with_state("foo"); + + let client = TestClient::from_service(svc); + let res = client.get("/").send().await; + assert_eq!(res.text().await, "foo"); + } } diff --git a/examples/key-value-store/Cargo.toml b/examples/key-value-store/Cargo.toml index 73fb191b..7db05d60 100644 --- a/examples/key-value-store/Cargo.toml +++ b/examples/key-value-store/Cargo.toml @@ -15,5 +15,6 @@ tower-http = { version = "0.3.0", features = [ "limit", "trace", ] } +tower-layer = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index b9580102..e176436b 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -26,7 +26,7 @@ use std::{ use tower::{BoxError, ServiceBuilder}; use tower_http::{ auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer, - trace::TraceLayer, + trace::TraceLayer, ServiceBuilderExt, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -50,10 +50,12 @@ async fn main() { get(kv_get.layer(CompressionLayer::new())) // But don't compress `kv_set` .post_service( - ServiceBuilder::new() - .layer(DefaultBodyLimit::disable()) - .layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */)) - .service(kv_set.with_state(Arc::clone(&shared_state))), + kv_set + .layer(( + DefaultBodyLimit::disable(), + RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */), + )) + .with_state(Arc::clone(&shared_state)), ), ) .route("/keys", get(list_keys))