diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 5aee01d8..9b7316fb 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` ([#733]) - **fixed:** Correctly set the `Content-Length` header for response to `HEAD` requests ([#734]) +- **fixed:** Fix wrong `content-length` for `HEAD` requests to endpoints that returns chunked + responses ([#755]) [#644]: https://github.com/tokio-rs/axum/pull/644 [#665]: https://github.com/tokio-rs/axum/pull/665 @@ -62,6 +64,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#719]: https://github.com/tokio-rs/axum/pull/719 [#733]: https://github.com/tokio-rs/axum/pull/733 [#734]: https://github.com/tokio-rs/axum/pull/734 +[#755]: https://github.com/tokio-rs/axum/pull/755 # 0.4.4 (13. January, 2022) diff --git a/axum/src/routing/future.rs b/axum/src/routing/future.rs index 013ed894..d2f097b9 100644 --- a/axum/src/routing/future.rs +++ b/axum/src/routing/future.rs @@ -1,21 +1,3 @@ //! Future types. -use crate::response::Response; -use std::convert::Infallible; - pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture}; - -opaque_future! { - /// Response future for [`Router`](super::Router). - pub type RouterFuture = RouteFuture; -} - -impl RouterFuture { - pub(super) fn from_future(future: RouteFuture) -> Self { - Self::new(future) - } - - pub(super) fn from_response(response: Response) -> Self { - Self::new(RouteFuture::from_response(response)) - } -} diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 79f6b04f..5ddc838c 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -983,7 +983,7 @@ where ) => { if $method == Method::$method_variant { if let Some(svc) = $svc { - return RouteFuture::from_future(svc.0.clone().oneshot($req)) + return RouteFuture::from_future(svc.oneshot_inner($req)) .strip_body($method == Method::HEAD); } } @@ -1018,11 +1018,9 @@ where call!(req, method, TRACE, trace); let future = match fallback { - Fallback::Default(fallback) => { - RouteFuture::from_future(fallback.0.clone().oneshot(req)) - .strip_body(method == Method::HEAD) - } - Fallback::Custom(fallback) => RouteFuture::from_future(fallback.0.clone().oneshot(req)) + Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) + .strip_body(method == Method::HEAD), + Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req)) .strip_body(method == Method::HEAD), }; diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 931bcdbb..7d0d2660 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::{future::RouterFuture, not_found::NotFound}; +use self::{future::RouteFuture, not_found::NotFound}; use crate::{ body::{boxed, Body, Bytes, HttpBody}, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, @@ -396,7 +396,11 @@ where } #[inline] - fn call_route(&self, match_: matchit::Match<&RouteId>, mut req: Request) -> RouterFuture { + fn call_route( + &self, + match_: matchit::Match<&RouteId>, + mut req: Request, + ) -> RouteFuture { let id = *match_.value; req.extensions_mut().insert(id); @@ -440,11 +444,10 @@ where .expect("no route for id. This is a bug in axum. Please file an issue") .clone(); - let future = match &mut route { + match &mut route { Endpoint::MethodRouter(inner) => inner.call(req), Endpoint::Route(inner) => inner.call(req), - }; - RouterFuture::from_future(future) + } } fn panic_on_matchit_error(&self, err: matchit::InsertError) { @@ -465,7 +468,7 @@ where { type Response = Response; type Error = Infallible; - type Future = RouterFuture; + type Future = RouteFuture; #[inline] fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { @@ -496,15 +499,11 @@ where with_path(req.uri(), &format!("{}/", path)) }; let res = Redirect::permanent(redirect_to); - RouterFuture::from_response(res.into_response()) + RouteFuture::from_response(res.into_response()) } else { match &self.fallback { - Fallback::Default(inner) => { - RouterFuture::from_future(inner.clone().call(req)) - } - Fallback::Custom(inner) => { - RouterFuture::from_future(inner.clone().call(req)) - } + Fallback::Default(inner) => inner.clone().call(req), + Fallback::Custom(inner) => inner.clone().call(req), } } } diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 7fea7f44..53fbad46 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -25,7 +25,7 @@ use tower_service::Service; /// /// You normally shouldn't need to care about this type. It's used in /// [`Router::layer`](super::Router::layer). -pub struct Route(pub(crate) BoxCloneService, Response, E>); +pub struct Route(BoxCloneService, Response, E>); impl Route { pub(super) fn new(svc: T) -> Self @@ -35,6 +35,13 @@ impl Route { { Self(BoxCloneService::new(svc)) } + + pub(crate) fn oneshot_inner( + &mut self, + req: Request, + ) -> Oneshot, Response, E>, Request> { + self.0.clone().oneshot(req) + } } impl Clone for Route { @@ -64,7 +71,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - RouteFuture::from_future(self.0.clone().oneshot(req)) + RouteFuture::from_future(self.oneshot_inner(req)) } } @@ -134,6 +141,9 @@ where #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + #[derive(Clone, Copy)] + struct AlreadyPassedThroughRouteFuture; + let this = self.project(); let mut res = match this.kind.project() { @@ -147,6 +157,16 @@ where } }; + if res + .extensions() + .get::() + .is_some() + { + return Poll::Ready(Ok(res)); + } else { + res.extensions_mut().insert(AlreadyPassedThroughRouteFuture); + } + set_allow_header(&mut res, this.allow_header); // make sure to set content-length before removing the body diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index a30c52fc..9922a611 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -645,3 +645,36 @@ async fn head_content_length_through_hyper_server_that_hits_fallback() { let res = client.head("/").send().await; assert_eq!(res.headers()["content-length"], "3"); } + +#[tokio::test] +async fn head_with_middleware_applied() { + use tower_http::compression::{predicate::SizeAbove, CompressionLayer}; + + let app = Router::new() + .route("/", get(|| async { "Hello, World!" })) + .layer(CompressionLayer::new().compress_when(SizeAbove::new(0))); + + let client = TestClient::new(app); + + // send GET request + let res = client + .get("/") + .header("accept-encoding", "gzip") + .send() + .await; + assert_eq!(res.headers()["transfer-encoding"], "chunked"); + // cannot have `transfer-encoding: chunked` and `content-length` + assert!(!res.headers().contains_key("content-length")); + + // send HEAD request + let res = client + .head("/") + .header("accept-encoding", "gzip") + .send() + .await; + // no response body so no `transfer-encoding` + assert!(!res.headers().contains_key("transfer-encoding")); + // no content-length since we cannot know it since the response + // is compressed + assert!(!res.headers().contains_key("content-length")); +}