fix: Avoid setting content-length before middleware (#3031)

This commit is contained in:
Sabrina Jewson 2024-11-16 15:50:25 +00:00 committed by GitHub
parent 893bb75e3b
commit ce3d42947e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 11 deletions

View file

@ -18,7 +18,7 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{ use tower::{
util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot},
ServiceExt, ServiceExt,
}; };
use tower_layer::Layer; use tower_layer::Layer;
@ -73,7 +73,6 @@ impl<E> Route<E> {
NewError: 'static, NewError: 'static,
{ {
let layer = ( let layer = (
MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
MapErrLayer::new(Into::into), MapErrLayer::new(Into::into),
MapResponseLayer::new(IntoResponse::into_response), MapResponseLayer::new(IntoResponse::into_response),
layer, layer,
@ -113,7 +112,7 @@ where
#[inline] #[inline]
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let req = req.map(Body::new); let req = req.map(Body::new);
RouteFuture::from_future(self.oneshot_inner(req)) RouteFuture::from_future(self.oneshot_inner(req)).not_top_level()
} }
} }
@ -124,6 +123,7 @@ pin_project! {
kind: RouteFutureKind<E>, kind: RouteFutureKind<E>,
strip_body: bool, strip_body: bool,
allow_header: Option<Bytes>, allow_header: Option<Bytes>,
top_level: bool,
} }
} }
@ -151,6 +151,7 @@ impl<E> RouteFuture<E> {
kind: RouteFutureKind::Future { future }, kind: RouteFutureKind::Future { future },
strip_body: false, strip_body: false,
allow_header: None, allow_header: None,
top_level: true,
} }
} }
@ -163,6 +164,11 @@ impl<E> RouteFuture<E> {
self.allow_header = Some(allow_header); self.allow_header = Some(allow_header);
self self
} }
pub(crate) fn not_top_level(mut self) -> Self {
self.top_level = false;
self
}
} }
impl<E> Future for RouteFuture<E> { impl<E> Future for RouteFuture<E> {
@ -183,16 +189,16 @@ impl<E> Future for RouteFuture<E> {
} }
}; };
if *this.top_level {
set_allow_header(res.headers_mut(), this.allow_header); set_allow_header(res.headers_mut(), this.allow_header);
// make sure to set content-length before removing the body // make sure to set content-length before removing the body
set_content_length(res.size_hint(), res.headers_mut()); set_content_length(res.size_hint(), res.headers_mut());
let res = if *this.strip_body { if *this.strip_body {
res.map(|_| Body::empty()) *res.body_mut() = Body::empty();
} else { }
res }
};
Poll::Ready(Ok(res)) Poll::Ready(Ok(res))
} }

View file

@ -1087,3 +1087,22 @@ async fn locks_mutex_very_little() {
assert_eq!(num, 1); assert_eq!(num, 1);
} }
} }
#[crate::test]
async fn middleware_adding_body() {
let app = Router::new()
.route("/", get(()))
.layer(MapResponseLayer::new(|mut res: Response| -> Response {
// If there is a content-length header, its value will be zero and axum will avoid
// overwriting it. But this means our content-length doesn't match the length of the
// body, which leads to panics in Hyper. Thus we have to ensure that axum doesn't add
// on content-length headers until after middleware has been run.
assert!(!res.headers().contains_key("content-length"));
*res.body_mut() = "".into();
res
}));
let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.text().await, "");
}