mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-11 12:31:25 +01:00
Fix Handler::with_state
not working if request body was changed via layer (#1536)
Previously ```rust handler.layer(RequestBodyLimitLayer::new(...)).with_state(...) ``` didn't work because we required the same request body all the way through.
This commit is contained in:
parent
b1f894a500
commit
2e8a7e51a1
7 changed files with 56 additions and 24 deletions
|
@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
|
|||
| |
|
||||
| required by a bound introduced by this call
|
||||
|
|
||||
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
|
||||
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
|
||||
note: required by a bound in `axum::routing::get`
|
||||
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
||||
|
|
||||
|
|
|
@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {fo
|
|||
| |
|
||||
| required by a bound introduced by this call
|
||||
|
|
||||
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
|
||||
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
|
||||
note: required by a bound in `axum::routing::get`
|
||||
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
||||
|
|
||||
|
|
|
@ -12,7 +12,7 @@ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {hand
|
|||
| |
|
||||
| required by a bound introduced by this call
|
||||
|
|
||||
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
|
||||
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
|
||||
note: required by a bound in `axum::routing::get`
|
||||
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
||||
|
|
||||
|
@ -28,7 +28,7 @@ error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Futu
|
|||
| |
|
||||
| required by a bound introduced by this call
|
||||
|
|
||||
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
|
||||
= help: the trait `Handler<T, S, B2>` is implemented for `Layered<L, H, T, S, B, B2>`
|
||||
note: required by a bound in `MethodRouter::<S, B>::post`
|
||||
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
||||
|
|
||||
|
|
|
@ -48,7 +48,7 @@ serde = "1.0"
|
|||
sync_wrapper = "0.1.1"
|
||||
tower = { version = "0.4.13", default-features = false, features = ["util"] }
|
||||
tower-http = { version = "0.3.0", features = ["util", "map-response-body"] }
|
||||
tower-layer = "0.3"
|
||||
tower-layer = "0.3.2"
|
||||
tower-service = "0.3"
|
||||
|
||||
# optional dependencies
|
||||
|
|
|
@ -139,9 +139,10 @@ pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
|
|||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
fn layer<L>(self, layer: L) -> Layered<L, Self, T, S, B>
|
||||
fn layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody>
|
||||
where
|
||||
L: Layer<HandlerService<Self, T, S, B>> + Clone,
|
||||
L::Service: Service<Request<NewReqBody>>,
|
||||
{
|
||||
Layered {
|
||||
layer,
|
||||
|
@ -220,13 +221,13 @@ all_the_tuples!(impl_handler);
|
|||
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
|
||||
///
|
||||
/// Created with [`Handler::layer`]. See that method for more details.
|
||||
pub struct Layered<L, H, T, S, B> {
|
||||
pub struct Layered<L, H, T, S, B, B2> {
|
||||
layer: L,
|
||||
handler: H,
|
||||
_marker: PhantomData<fn() -> (T, S, B)>,
|
||||
_marker: PhantomData<fn() -> (T, S, B, B2)>,
|
||||
}
|
||||
|
||||
impl<L, H, T, S, B> fmt::Debug for Layered<L, H, T, S, B>
|
||||
impl<L, H, T, S, B, B2> fmt::Debug for Layered<L, H, T, S, B, B2>
|
||||
where
|
||||
L: fmt::Debug,
|
||||
{
|
||||
|
@ -237,7 +238,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<L, H, T, S, B> Clone for Layered<L, H, T, S, B>
|
||||
impl<L, H, T, S, B, B2> Clone for Layered<L, H, T, S, B, B2>
|
||||
where
|
||||
L: Clone,
|
||||
H: Clone,
|
||||
|
@ -251,20 +252,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<H, S, T, B, L> Handler<T, S, B> for Layered<L, H, T, S, B>
|
||||
impl<H, S, T, L, B, B2> Handler<T, S, B2> for Layered<L, H, T, S, B, B2>
|
||||
where
|
||||
L: Layer<HandlerService<H, T, S, B>> + Clone + Send + 'static,
|
||||
H: Handler<T, S, B>,
|
||||
L::Service: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<B>>>::Response: IntoResponse,
|
||||
<L::Service as Service<Request<B>>>::Future: Send,
|
||||
L::Service: Service<Request<B2>, Error = Infallible> + Clone + Send + 'static,
|
||||
<L::Service as Service<Request<B2>>>::Response: IntoResponse,
|
||||
<L::Service as Service<Request<B2>>>::Future: Send,
|
||||
T: 'static,
|
||||
S: 'static,
|
||||
B: Send + 'static,
|
||||
B2: Send + 'static,
|
||||
{
|
||||
type Future = future::LayeredFuture<B, L::Service>;
|
||||
type Future = future::LayeredFuture<B2, L::Service>;
|
||||
|
||||
fn call(self, req: Request<B>, state: S) -> Self::Future {
|
||||
fn call(self, req: Request<B2>, state: S) -> Self::Future {
|
||||
use futures_util::future::{FutureExt, Map};
|
||||
|
||||
let svc = self.handler.with_state(state);
|
||||
|
@ -274,8 +276,8 @@ where
|
|||
_,
|
||||
fn(
|
||||
Result<
|
||||
<L::Service as Service<Request<B>>>::Response,
|
||||
<L::Service as Service<Request<B>>>::Error,
|
||||
<L::Service as Service<Request<B2>>>::Response,
|
||||
<L::Service as Service<Request<B2>>>::Error,
|
||||
>,
|
||||
) -> _,
|
||||
> = svc.oneshot(req).map(|result| match result {
|
||||
|
@ -338,8 +340,14 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::test_helpers::*;
|
||||
use crate::{body, extract::State, test_helpers::*};
|
||||
use http::StatusCode;
|
||||
use std::time::Duration;
|
||||
use tower_http::{
|
||||
compression::CompressionLayer, limit::RequestBodyLimitLayer,
|
||||
map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer,
|
||||
timeout::TimeoutLayer,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn handler_into_service() {
|
||||
|
@ -353,4 +361,25 @@ mod tests {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "you said: hi there!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn with_layer_that_changes_request_body_and_state() {
|
||||
async fn handle(State(state): State<&'static str>) -> &'static str {
|
||||
state
|
||||
}
|
||||
|
||||
let svc = handle
|
||||
.layer((
|
||||
RequestBodyLimitLayer::new(1024),
|
||||
TimeoutLayer::new(Duration::from_secs(10)),
|
||||
MapResponseBodyLayer::new(body::boxed),
|
||||
CompressionLayer::new(),
|
||||
))
|
||||
.layer(MapRequestBodyLayer::new(body::boxed))
|
||||
.with_state("foo");
|
||||
|
||||
let client = TestClient::from_service(svc);
|
||||
let res = client.get("/").send().await;
|
||||
assert_eq!(res.text().await, "foo");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,5 +15,6 @@ tower-http = { version = "0.3.0", features = [
|
|||
"limit",
|
||||
"trace",
|
||||
] }
|
||||
tower-layer = "0.3.2"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
|
|
@ -26,7 +26,7 @@ use std::{
|
|||
use tower::{BoxError, ServiceBuilder};
|
||||
use tower_http::{
|
||||
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
|
||||
trace::TraceLayer,
|
||||
trace::TraceLayer, ServiceBuilderExt,
|
||||
};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
|
@ -50,10 +50,12 @@ async fn main() {
|
|||
get(kv_get.layer(CompressionLayer::new()))
|
||||
// But don't compress `kv_set`
|
||||
.post_service(
|
||||
ServiceBuilder::new()
|
||||
.layer(DefaultBodyLimit::disable())
|
||||
.layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */))
|
||||
.service(kv_set.with_state(Arc::clone(&shared_state))),
|
||||
kv_set
|
||||
.layer((
|
||||
DefaultBodyLimit::disable(),
|
||||
RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */),
|
||||
))
|
||||
.with_state(Arc::clone(&shared_state)),
|
||||
),
|
||||
)
|
||||
.route("/keys", get(list_keys))
|
||||
|
|
Loading…
Reference in a new issue