Fix setting content-length: 0 for HEAD responses with compression middleware (#755)

* Fix getting `content-length` for `chunked` responses

Fixes #747

* changelog

* Fix `cargo deny bans`

https://github.com/tokio-rs/axum/pull/753 will fix things properly
This commit is contained in:
David Pedersen 2022-02-14 10:56:08 +01:00 committed by GitHub
parent 409a6651f5
commit c135436cc9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 39 deletions

View file

@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` ([#733])
- **fixed:** Correctly set the `Content-Length` header for response to `HEAD`
requests ([#734])
- **fixed:** Fix wrong `content-length` for `HEAD` requests to endpoints that returns chunked
responses ([#755])
[#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665
@ -62,6 +64,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#719]: https://github.com/tokio-rs/axum/pull/719
[#733]: https://github.com/tokio-rs/axum/pull/733
[#734]: https://github.com/tokio-rs/axum/pull/734
[#755]: https://github.com/tokio-rs/axum/pull/755
# 0.4.4 (13. January, 2022)

View file

@ -1,21 +1,3 @@
//! Future types.
use crate::response::Response;
use std::convert::Infallible;
pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture};
opaque_future! {
/// Response future for [`Router`](super::Router).
pub type RouterFuture<B> = RouteFuture<B, Infallible>;
}
impl<B> RouterFuture<B> {
pub(super) fn from_future(future: RouteFuture<B, Infallible>) -> Self {
Self::new(future)
}
pub(super) fn from_response(response: Response) -> Self {
Self::new(RouteFuture::from_response(response))
}
}

View file

@ -983,7 +983,7 @@ where
) => {
if $method == Method::$method_variant {
if let Some(svc) = $svc {
return RouteFuture::from_future(svc.0.clone().oneshot($req))
return RouteFuture::from_future(svc.oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
}
@ -1018,11 +1018,9 @@ where
call!(req, method, TRACE, trace);
let future = match fallback {
Fallback::Default(fallback) => {
RouteFuture::from_future(fallback.0.clone().oneshot(req))
.strip_body(method == Method::HEAD)
}
Fallback::Custom(fallback) => RouteFuture::from_future(fallback.0.clone().oneshot(req))
Fallback::Default(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::Custom(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
};

View file

@ -1,6 +1,6 @@
//! Routing between [`Service`]s and handlers.
use self::{future::RouterFuture, not_found::NotFound};
use self::{future::RouteFuture, not_found::NotFound};
use crate::{
body::{boxed, Body, Bytes, HttpBody},
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
@ -396,7 +396,11 @@ where
}
#[inline]
fn call_route(&self, match_: matchit::Match<&RouteId>, mut req: Request<B>) -> RouterFuture<B> {
fn call_route(
&self,
match_: matchit::Match<&RouteId>,
mut req: Request<B>,
) -> RouteFuture<B, Infallible> {
let id = *match_.value;
req.extensions_mut().insert(id);
@ -440,11 +444,10 @@ where
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
let future = match &mut route {
match &mut route {
Endpoint::MethodRouter(inner) => inner.call(req),
Endpoint::Route(inner) => inner.call(req),
};
RouterFuture::from_future(future)
}
}
fn panic_on_matchit_error(&self, err: matchit::InsertError) {
@ -465,7 +468,7 @@ where
{
type Response = Response;
type Error = Infallible;
type Future = RouterFuture<B>;
type Future = RouteFuture<B, Infallible>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -496,15 +499,11 @@ where
with_path(req.uri(), &format!("{}/", path))
};
let res = Redirect::permanent(redirect_to);
RouterFuture::from_response(res.into_response())
RouteFuture::from_response(res.into_response())
} else {
match &self.fallback {
Fallback::Default(inner) => {
RouterFuture::from_future(inner.clone().call(req))
}
Fallback::Custom(inner) => {
RouterFuture::from_future(inner.clone().call(req))
}
Fallback::Default(inner) => inner.clone().call(req),
Fallback::Custom(inner) => inner.clone().call(req),
}
}
}

View file

@ -25,7 +25,7 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
pub struct Route<B = Body, E = Infallible>(pub(crate) BoxCloneService<Request<B>, Response, E>);
pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
impl<B, E> Route<B, E> {
pub(super) fn new<T>(svc: T) -> Self
@ -35,6 +35,13 @@ impl<B, E> Route<B, E> {
{
Self(BoxCloneService::new(svc))
}
pub(crate) fn oneshot_inner(
&mut self,
req: Request<B>,
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
self.0.clone().oneshot(req)
}
}
impl<ReqBody, E> Clone for Route<ReqBody, E> {
@ -64,7 +71,7 @@ where
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
RouteFuture::from_future(self.0.clone().oneshot(req))
RouteFuture::from_future(self.oneshot_inner(req))
}
}
@ -134,6 +141,9 @@ where
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[derive(Clone, Copy)]
struct AlreadyPassedThroughRouteFuture;
let this = self.project();
let mut res = match this.kind.project() {
@ -147,6 +157,16 @@ where
}
};
if res
.extensions()
.get::<AlreadyPassedThroughRouteFuture>()
.is_some()
{
return Poll::Ready(Ok(res));
} else {
res.extensions_mut().insert(AlreadyPassedThroughRouteFuture);
}
set_allow_header(&mut res, this.allow_header);
// make sure to set content-length before removing the body

View file

@ -645,3 +645,36 @@ async fn head_content_length_through_hyper_server_that_hits_fallback() {
let res = client.head("/").send().await;
assert_eq!(res.headers()["content-length"], "3");
}
#[tokio::test]
async fn head_with_middleware_applied() {
use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
let client = TestClient::new(app);
// send GET request
let res = client
.get("/")
.header("accept-encoding", "gzip")
.send()
.await;
assert_eq!(res.headers()["transfer-encoding"], "chunked");
// cannot have `transfer-encoding: chunked` and `content-length`
assert!(!res.headers().contains_key("content-length"));
// send HEAD request
let res = client
.head("/")
.header("accept-encoding", "gzip")
.send()
.await;
// no response body so no `transfer-encoding`
assert!(!res.headers().contains_key("transfer-encoding"));
// no content-length since we cannot know it since the response
// is compressed
assert!(!res.headers().contains_key("content-length"));
}