Fix Handler::with_state not working if request body was changed via layer (#1536)

Previously

```rust
handler.layer(RequestBodyLimitLayer::new(...)).with_state(...)
```

didn't work because we required the same request body all the way
through.
This commit is contained in:
David Pedersen 2022-11-18 11:00:52 +01:00 committed by GitHub
parent b1f894a500
commit 2e8a7e51a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 56 additions and 24 deletions

View file

@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {hand
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
@ -28,7 +28,7 @@ error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Futu
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
note: required by a bound in `MethodRouter::<S, B>::post`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -48,7 +48,7 @@ serde = "1.0"
sync_wrapper = "0.1.1"
tower = { version = "0.4.13", default-features = false, features = ["util"] }
tower-http = { version = "0.3.0", features = ["util", "map-response-body"] }
tower-layer = "0.3"
tower-layer = "0.3.2"
tower-service = "0.3"
# optional dependencies

View file

@ -139,9 +139,10 @@ pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
fn layer<L>(self, layer: L) -> Layered<L, Self, T, S, B>
fn layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody>
where
L: Layer<HandlerService<Self, T, S, B>> + Clone,
L::Service: Service<Request<NewReqBody>>,
{
Layered {
layer,
@ -220,13 +221,13 @@ all_the_tuples!(impl_handler);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///
/// Created with [`Handler::layer`]. See that method for more details.
pub struct Layered<L, H, T, S, B> {
pub struct Layered<L, H, T, S, B, B2> {
layer: L,
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
_marker: PhantomData<fn() -> (T, S, B, B2)>,
}
impl<L, H, T, S, B> fmt::Debug for Layered<L, H, T, S, B>
impl<L, H, T, S, B, B2> fmt::Debug for Layered<L, H, T, S, B, B2>
where
L: fmt::Debug,
{
@ -237,7 +238,7 @@ where
}
}
impl<L, H, T, S, B> Clone for Layered<L, H, T, S, B>
impl<L, H, T, S, B, B2> Clone for Layered<L, H, T, S, B, B2>
where
L: Clone,
H: Clone,
@ -251,20 +252,21 @@ where
}
}
impl<H, S, T, B, L> Handler<T, S, B> for Layered<L, H, T, S, B>
impl<H, S, T, L, B, B2> Handler<T, S, B2> for Layered<L, H, T, S, B, B2>
where
L: Layer<HandlerService<H, T, S, B>> + Clone + Send + 'static,
H: Handler<T, S, B>,
L::Service: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse,
<L::Service as Service<Request<B>>>::Future: Send,
L::Service: Service<Request<B2>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B2>>>::Response: IntoResponse,
<L::Service as Service<Request<B2>>>::Future: Send,
T: 'static,
S: 'static,
B: Send + 'static,
B2: Send + 'static,
{
type Future = future::LayeredFuture<B, L::Service>;
type Future = future::LayeredFuture<B2, L::Service>;
fn call(self, req: Request<B>, state: S) -> Self::Future {
fn call(self, req: Request<B2>, state: S) -> Self::Future {
use futures_util::future::{FutureExt, Map};
let svc = self.handler.with_state(state);
@ -274,8 +276,8 @@ where
_,
fn(
Result<
<L::Service as Service<Request<B>>>::Response,
<L::Service as Service<Request<B>>>::Error,
<L::Service as Service<Request<B2>>>::Response,
<L::Service as Service<Request<B2>>>::Error,
>,
) -> _,
> = svc.oneshot(req).map(|result| match result {
@ -338,8 +340,14 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::*;
use crate::{body, extract::State, test_helpers::*};
use http::StatusCode;
use std::time::Duration;
use tower_http::{
compression::CompressionLayer, limit::RequestBodyLimitLayer,
map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer,
timeout::TimeoutLayer,
};
#[tokio::test]
async fn handler_into_service() {
@ -353,4 +361,25 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "you said: hi there!");
}
#[tokio::test]
async fn with_layer_that_changes_request_body_and_state() {
async fn handle(State(state): State<&'static str>) -> &'static str {
state
}
let svc = handle
.layer((
RequestBodyLimitLayer::new(1024),
TimeoutLayer::new(Duration::from_secs(10)),
MapResponseBodyLayer::new(body::boxed),
CompressionLayer::new(),
))
.layer(MapRequestBodyLayer::new(body::boxed))
.with_state("foo");
let client = TestClient::from_service(svc);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "foo");
}
}

View file

@ -15,5 +15,6 @@ tower-http = { version = "0.3.0", features = [
"limit",
"trace",
] }
tower-layer = "0.3.2"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -26,7 +26,7 @@ use std::{
use tower::{BoxError, ServiceBuilder};
use tower_http::{
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
trace::TraceLayer,
trace::TraceLayer, ServiceBuilderExt,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -50,10 +50,12 @@ async fn main() {
get(kv_get.layer(CompressionLayer::new()))
// But don't compress `kv_set`
.post_service(
ServiceBuilder::new()
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */))
.service(kv_set.with_state(Arc::clone(&shared_state))),
kv_set
.layer((
DefaultBodyLimit::disable(),
RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */),
))
.with_state(Arc::clone(&shared_state)),
),
)
.route("/keys", get(list_keys))