diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1e6dcc9d..046f7d09 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -11,7 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)` - **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` and `MethodRouter::connect[_service]` ([#2961]) +- **fixed:** Avoid setting `content-length` before middleware ([#2897]). + This allows middleware to add bodies to requests without needing to manually set `content-length` +[#2897]: https://github.com/tokio-rs/axum/pull/2897 [#2984]: https://github.com/tokio-rs/axum/pull/2984 [#2961]: https://github.com/tokio-rs/axum/pull/2961 diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 0b88fbcc..0953dfbc 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -15,10 +15,10 @@ use std::{ fmt, future::Future, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use tower::{ - util::{MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, + util::{MapErrLayer, MapResponseLayer, Oneshot}, ServiceExt, }; use tower_layer::Layer; @@ -44,7 +44,7 @@ impl Route { pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture { let method = req.method().clone(); - RouteFuture::from_future(method, self.0.clone().oneshot(req)) + RouteFuture::new(method, self.0.clone().oneshot(req)) } pub(crate) fn layer(self, layer: L) -> Route @@ -57,7 +57,6 @@ impl Route { NewError: 'static, { let layer = ( - MapRequestLayer::new(|req: Request<_>| req.map(Body::new)), MapErrLayer::new(Into::into), MapResponseLayer::new(IntoResponse::into_response), layer, @@ -96,7 +95,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - self.oneshot_inner(req.map(Body::new)) + self.oneshot_inner(req.map(Body::new)).not_top_level() } } @@ -104,37 +103,20 @@ pin_project! { /// Response future for [`Route`]. pub struct RouteFuture { #[pin] - kind: RouteFutureKind, + inner: Oneshot, Request>, method: Method, allow_header: Option, - } -} - -pin_project! { - #[project = RouteFutureKindProj] - enum RouteFutureKind { - Future { - #[pin] - future: Oneshot< - BoxCloneService, - Request, - >, - }, - Response { - response: Option, - } + top_level: bool, } } impl RouteFuture { - pub(crate) fn from_future( - method: Method, - future: Oneshot, Request>, - ) -> Self { + fn new(method: Method, inner: Oneshot, Request>) -> Self { Self { - kind: RouteFutureKind::Future { future }, + inner, method, allow_header: None, + top_level: true, } } @@ -142,6 +124,11 @@ impl RouteFuture { self.allow_header = Some(allow_header); self } + + pub(crate) fn not_top_level(mut self) -> Self { + self.top_level = false; + self + } } impl Future for RouteFuture { @@ -150,19 +137,7 @@ impl Future for RouteFuture { #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - - let mut res = match this.kind.project() { - RouteFutureKindProj::Future { future } => match future.poll(cx) { - Poll::Ready(Ok(res)) => res, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }, - RouteFutureKindProj::Response { response } => { - response.take().expect("future polled after completion") - } - }; - - set_allow_header(res.headers_mut(), this.allow_header); + let mut res = ready!(this.inner.poll(cx))?; if *this.method == Method::CONNECT && res.status().is_success() { // From https://httpwg.org/specs/rfc9110.html#CONNECT: @@ -176,16 +151,16 @@ impl Future for RouteFuture { error!("response to CONNECT with nonempty body"); res = res.map(|_| Body::empty()); } - } else { + } else if *this.top_level { + set_allow_header(res.headers_mut(), this.allow_header); + // make sure to set content-length before removing the body set_content_length(res.size_hint(), res.headers_mut()); - } - let res = if *this.method == Method::HEAD { - res.map(|_| Body::empty()) - } else { - res - }; + if *this.method == Method::HEAD { + *res.body_mut() = Body::empty(); + } + } Poll::Ready(Ok(res)) } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index db5ca480..c4e37179 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1073,3 +1073,22 @@ async fn colon_in_route() { async fn asterisk_in_route() { _ = Router::<()>::new().route("/*foo", get(|| async move {})); } + +#[crate::test] +async fn middleware_adding_body() { + let app = Router::new() + .route("/", get(())) + .layer(MapResponseLayer::new(|mut res: Response| -> Response { + // If there is a content-length header, its value will be zero and Axum will avoid + // overwriting it. But this means our content-length doesn’t match the length of the + // body, which leads to panics in Hyper. Thus we have to ensure that Axum doesn’t add + // on content-length headers until after middleware has been run. + assert!(!res.headers().contains_key("content-length")); + *res.body_mut() = "…".into(); + res + })); + + let client = TestClient::new(app); + let res = client.get("/").await; + assert_eq!(res.text().await, "…"); +}