mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-30 04:02:42 +01:00
fix: Avoid setting content-length before middleware (#3031)
This commit is contained in:
parent
893bb75e3b
commit
ce3d42947e
2 changed files with 36 additions and 11 deletions
|
@ -18,7 +18,7 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tower::{
|
use tower::{
|
||||||
util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
|
util::{BoxCloneService, MapErrLayer, MapResponseLayer, Oneshot},
|
||||||
ServiceExt,
|
ServiceExt,
|
||||||
};
|
};
|
||||||
use tower_layer::Layer;
|
use tower_layer::Layer;
|
||||||
|
@ -73,7 +73,6 @@ impl<E> Route<E> {
|
||||||
NewError: 'static,
|
NewError: 'static,
|
||||||
{
|
{
|
||||||
let layer = (
|
let layer = (
|
||||||
MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
|
|
||||||
MapErrLayer::new(Into::into),
|
MapErrLayer::new(Into::into),
|
||||||
MapResponseLayer::new(IntoResponse::into_response),
|
MapResponseLayer::new(IntoResponse::into_response),
|
||||||
layer,
|
layer,
|
||||||
|
@ -113,7 +112,7 @@ where
|
||||||
#[inline]
|
#[inline]
|
||||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||||
let req = req.map(Body::new);
|
let req = req.map(Body::new);
|
||||||
RouteFuture::from_future(self.oneshot_inner(req))
|
RouteFuture::from_future(self.oneshot_inner(req)).not_top_level()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,6 +123,7 @@ pin_project! {
|
||||||
kind: RouteFutureKind<E>,
|
kind: RouteFutureKind<E>,
|
||||||
strip_body: bool,
|
strip_body: bool,
|
||||||
allow_header: Option<Bytes>,
|
allow_header: Option<Bytes>,
|
||||||
|
top_level: bool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,6 +151,7 @@ impl<E> RouteFuture<E> {
|
||||||
kind: RouteFutureKind::Future { future },
|
kind: RouteFutureKind::Future { future },
|
||||||
strip_body: false,
|
strip_body: false,
|
||||||
allow_header: None,
|
allow_header: None,
|
||||||
|
top_level: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,6 +164,11 @@ impl<E> RouteFuture<E> {
|
||||||
self.allow_header = Some(allow_header);
|
self.allow_header = Some(allow_header);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn not_top_level(mut self) -> Self {
|
||||||
|
self.top_level = false;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E> Future for RouteFuture<E> {
|
impl<E> Future for RouteFuture<E> {
|
||||||
|
@ -183,16 +189,16 @@ impl<E> Future for RouteFuture<E> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if *this.top_level {
|
||||||
set_allow_header(res.headers_mut(), this.allow_header);
|
set_allow_header(res.headers_mut(), this.allow_header);
|
||||||
|
|
||||||
// make sure to set content-length before removing the body
|
// make sure to set content-length before removing the body
|
||||||
set_content_length(res.size_hint(), res.headers_mut());
|
set_content_length(res.size_hint(), res.headers_mut());
|
||||||
|
|
||||||
let res = if *this.strip_body {
|
if *this.strip_body {
|
||||||
res.map(|_| Body::empty())
|
*res.body_mut() = Body::empty();
|
||||||
} else {
|
}
|
||||||
res
|
}
|
||||||
};
|
|
||||||
|
|
||||||
Poll::Ready(Ok(res))
|
Poll::Ready(Ok(res))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1087,3 +1087,22 @@ async fn locks_mutex_very_little() {
|
||||||
assert_eq!(num, 1);
|
assert_eq!(num, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn middleware_adding_body() {
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", get(()))
|
||||||
|
.layer(MapResponseLayer::new(|mut res: Response| -> Response {
|
||||||
|
// If there is a content-length header, its value will be zero and axum will avoid
|
||||||
|
// overwriting it. But this means our content-length doesn't match the length of the
|
||||||
|
// body, which leads to panics in Hyper. Thus we have to ensure that axum doesn't add
|
||||||
|
// on content-length headers until after middleware has been run.
|
||||||
|
assert!(!res.headers().contains_key("content-length"));
|
||||||
|
*res.body_mut() = "…".into();
|
||||||
|
res
|
||||||
|
}));
|
||||||
|
|
||||||
|
let client = TestClient::new(app);
|
||||||
|
let res = client.get("/").await;
|
||||||
|
assert_eq!(res.text().await, "…");
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue