Fix setting content-length: 0 for HEAD responses with compression middleware (#755)

* Fix getting `content-length` for `chunked` responses

Fixes #747

* changelog

* Fix `cargo deny bans`

https://github.com/tokio-rs/axum/pull/753 will fix things properly
This commit is contained in:
David Pedersen 2022-02-14 10:56:08 +01:00 committed by GitHub
parent 409a6651f5
commit c135436cc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 39 deletions

View file

@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` ([#733]) - **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` ([#733])
- **fixed:** Correctly set the `Content-Length` header for response to `HEAD` - **fixed:** Correctly set the `Content-Length` header for response to `HEAD`
requests ([#734]) requests ([#734])
- **fixed:** Fix wrong `content-length` for `HEAD` requests to endpoints that returns chunked
responses ([#755])
[#644]: https://github.com/tokio-rs/axum/pull/644 [#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665 [#665]: https://github.com/tokio-rs/axum/pull/665
@ -62,6 +64,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#719]: https://github.com/tokio-rs/axum/pull/719 [#719]: https://github.com/tokio-rs/axum/pull/719
[#733]: https://github.com/tokio-rs/axum/pull/733 [#733]: https://github.com/tokio-rs/axum/pull/733
[#734]: https://github.com/tokio-rs/axum/pull/734 [#734]: https://github.com/tokio-rs/axum/pull/734
[#755]: https://github.com/tokio-rs/axum/pull/755
# 0.4.4 (13. January, 2022) # 0.4.4 (13. January, 2022)

View file

@ -1,21 +1,3 @@
//! Future types. //! Future types.
use crate::response::Response;
use std::convert::Infallible;
pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture}; pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture};
opaque_future! {
/// Response future for [`Router`](super::Router).
pub type RouterFuture<B> = RouteFuture<B, Infallible>;
}
impl<B> RouterFuture<B> {
pub(super) fn from_future(future: RouteFuture<B, Infallible>) -> Self {
Self::new(future)
}
pub(super) fn from_response(response: Response) -> Self {
Self::new(RouteFuture::from_response(response))
}
}

View file

@ -983,7 +983,7 @@ where
) => { ) => {
if $method == Method::$method_variant { if $method == Method::$method_variant {
if let Some(svc) = $svc { if let Some(svc) = $svc {
return RouteFuture::from_future(svc.0.clone().oneshot($req)) return RouteFuture::from_future(svc.oneshot_inner($req))
.strip_body($method == Method::HEAD); .strip_body($method == Method::HEAD);
} }
} }
@ -1018,11 +1018,9 @@ where
call!(req, method, TRACE, trace); call!(req, method, TRACE, trace);
let future = match fallback { let future = match fallback {
Fallback::Default(fallback) => { Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
RouteFuture::from_future(fallback.0.clone().oneshot(req)) .strip_body(method == Method::HEAD),
.strip_body(method == Method::HEAD) Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
}
Fallback::Custom(fallback) => RouteFuture::from_future(fallback.0.clone().oneshot(req))
.strip_body(method == Method::HEAD), .strip_body(method == Method::HEAD),
}; };

View file

@ -1,6 +1,6 @@
//! Routing between [`Service`]s and handlers. //! Routing between [`Service`]s and handlers.
use self::{future::RouterFuture, not_found::NotFound}; use self::{future::RouteFuture, not_found::NotFound};
use crate::{ use crate::{
body::{boxed, Body, Bytes, HttpBody}, body::{boxed, Body, Bytes, HttpBody},
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
@ -396,7 +396,11 @@ where
} }
#[inline] #[inline]
fn call_route(&self, match_: matchit::Match<&RouteId>, mut req: Request<B>) -> RouterFuture<B> { fn call_route(
&self,
match_: matchit::Match<&RouteId>,
mut req: Request<B>,
) -> RouteFuture<B, Infallible> {
let id = *match_.value; let id = *match_.value;
req.extensions_mut().insert(id); req.extensions_mut().insert(id);
@ -440,11 +444,10 @@ where
.expect("no route for id. This is a bug in axum. Please file an issue") .expect("no route for id. This is a bug in axum. Please file an issue")
.clone(); .clone();
let future = match &mut route { match &mut route {
Endpoint::MethodRouter(inner) => inner.call(req), Endpoint::MethodRouter(inner) => inner.call(req),
Endpoint::Route(inner) => inner.call(req), Endpoint::Route(inner) => inner.call(req),
}; }
RouterFuture::from_future(future)
} }
fn panic_on_matchit_error(&self, err: matchit::InsertError) { fn panic_on_matchit_error(&self, err: matchit::InsertError) {
@ -465,7 +468,7 @@ where
{ {
type Response = Response; type Response = Response;
type Error = Infallible; type Error = Infallible;
type Future = RouterFuture<B>; type Future = RouteFuture<B, Infallible>;
#[inline] #[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -496,15 +499,11 @@ where
with_path(req.uri(), &format!("{}/", path)) with_path(req.uri(), &format!("{}/", path))
}; };
let res = Redirect::permanent(redirect_to); let res = Redirect::permanent(redirect_to);
RouterFuture::from_response(res.into_response()) RouteFuture::from_response(res.into_response())
} else { } else {
match &self.fallback { match &self.fallback {
Fallback::Default(inner) => { Fallback::Default(inner) => inner.clone().call(req),
RouterFuture::from_future(inner.clone().call(req)) Fallback::Custom(inner) => inner.clone().call(req),
}
Fallback::Custom(inner) => {
RouterFuture::from_future(inner.clone().call(req))
}
} }
} }
} }

View file

@ -25,7 +25,7 @@ use tower_service::Service;
/// ///
/// You normally shouldn't need to care about this type. It's used in /// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer). /// [`Router::layer`](super::Router::layer).
pub struct Route<B = Body, E = Infallible>(pub(crate) BoxCloneService<Request<B>, Response, E>); pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
impl<B, E> Route<B, E> { impl<B, E> Route<B, E> {
pub(super) fn new<T>(svc: T) -> Self pub(super) fn new<T>(svc: T) -> Self
@ -35,6 +35,13 @@ impl<B, E> Route<B, E> {
{ {
Self(BoxCloneService::new(svc)) Self(BoxCloneService::new(svc))
} }
pub(crate) fn oneshot_inner(
&mut self,
req: Request<B>,
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
self.0.clone().oneshot(req)
}
} }
impl<ReqBody, E> Clone for Route<ReqBody, E> { impl<ReqBody, E> Clone for Route<ReqBody, E> {
@ -64,7 +71,7 @@ where
#[inline] #[inline]
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
RouteFuture::from_future(self.0.clone().oneshot(req)) RouteFuture::from_future(self.oneshot_inner(req))
} }
} }
@ -134,6 +141,9 @@ where
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[derive(Clone, Copy)]
struct AlreadyPassedThroughRouteFuture;
let this = self.project(); let this = self.project();
let mut res = match this.kind.project() { let mut res = match this.kind.project() {
@ -147,6 +157,16 @@ where
} }
}; };
if res
.extensions()
.get::<AlreadyPassedThroughRouteFuture>()
.is_some()
{
return Poll::Ready(Ok(res));
} else {
res.extensions_mut().insert(AlreadyPassedThroughRouteFuture);
}
set_allow_header(&mut res, this.allow_header); set_allow_header(&mut res, this.allow_header);
// make sure to set content-length before removing the body // make sure to set content-length before removing the body

View file

@ -645,3 +645,36 @@ async fn head_content_length_through_hyper_server_that_hits_fallback() {
let res = client.head("/").send().await; let res = client.head("/").send().await;
assert_eq!(res.headers()["content-length"], "3"); assert_eq!(res.headers()["content-length"], "3");
} }
#[tokio::test]
async fn head_with_middleware_applied() {
use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
let client = TestClient::new(app);
// send GET request
let res = client
.get("/")
.header("accept-encoding", "gzip")
.send()
.await;
assert_eq!(res.headers()["transfer-encoding"], "chunked");
// cannot have `transfer-encoding: chunked` and `content-length`
assert!(!res.headers().contains_key("content-length"));
// send HEAD request
let res = client
.head("/")
.header("accept-encoding", "gzip")
.send()
.await;
// no response body so no `transfer-encoding`
assert!(!res.headers().contains_key("transfer-encoding"));
// no content-length since we cannot know it since the response
// is compressed
assert!(!res.headers().contains_key("content-length"));
}