fix: Avoid setting content-length before middleware (#2897)

This commit is contained in:
Sabrina Jewson 2024-10-06 21:21:45 +01:00 committed by GitHub
parent 822db3b1af
commit 5512b5b91f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 44 additions and 47 deletions

View file

@ -11,7 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`
- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]`
and `MethodRouter::connect[_service]` ([#2961])
- **fixed:** Avoid setting `content-length` before middleware ([#2897]).
This allows middleware to add bodies to requests without needing to manually set `content-length`
[#2897]: https://github.com/tokio-rs/axum/pull/2897
[#2984]: https://github.com/tokio-rs/axum/pull/2984
[#2961]: https://github.com/tokio-rs/axum/pull/2961

View file

@ -15,10 +15,10 @@ use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
task::{ready, Context, Poll},
};
use tower::{
util::{MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
util::{MapErrLayer, MapResponseLayer, Oneshot},
ServiceExt,
};
use tower_layer::Layer;
@ -44,7 +44,7 @@ impl<E> Route<E> {
pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
let method = req.method().clone();
RouteFuture::from_future(method, self.0.clone().oneshot(req))
RouteFuture::new(method, self.0.clone().oneshot(req))
}
pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
@ -57,7 +57,6 @@ impl<E> Route<E> {
NewError: 'static,
{
let layer = (
MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
MapErrLayer::new(Into::into),
MapResponseLayer::new(IntoResponse::into_response),
layer,
@ -96,7 +95,7 @@ where
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.oneshot_inner(req.map(Body::new))
self.oneshot_inner(req.map(Body::new)).not_top_level()
}
}
@ -104,37 +103,20 @@ pin_project! {
/// Response future for [`Route`].
pub struct RouteFuture<E> {
#[pin]
kind: RouteFutureKind<E>,
inner: Oneshot<BoxCloneService<Request, Response, E>, Request>,
method: Method,
allow_header: Option<Bytes>,
}
}
pin_project! {
#[project = RouteFutureKindProj]
enum RouteFutureKind<E> {
Future {
#[pin]
future: Oneshot<
BoxCloneService<Request, Response, E>,
Request,
>,
},
Response {
response: Option<Response>,
}
top_level: bool,
}
}
impl<E> RouteFuture<E> {
pub(crate) fn from_future(
method: Method,
future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
) -> Self {
fn new(method: Method, inner: Oneshot<BoxCloneService<Request, Response, E>, Request>) -> Self {
Self {
kind: RouteFutureKind::Future { future },
inner,
method,
allow_header: None,
top_level: true,
}
}
@ -142,6 +124,11 @@ impl<E> RouteFuture<E> {
self.allow_header = Some(allow_header);
self
}
pub(crate) fn not_top_level(mut self) -> Self {
self.top_level = false;
self
}
}
impl<E> Future for RouteFuture<E> {
@ -150,19 +137,7 @@ impl<E> Future for RouteFuture<E> {
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut res = match this.kind.project() {
RouteFutureKindProj::Future { future } => match future.poll(cx) {
Poll::Ready(Ok(res)) => res,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
},
RouteFutureKindProj::Response { response } => {
response.take().expect("future polled after completion")
}
};
set_allow_header(res.headers_mut(), this.allow_header);
let mut res = ready!(this.inner.poll(cx))?;
if *this.method == Method::CONNECT && res.status().is_success() {
// From https://httpwg.org/specs/rfc9110.html#CONNECT:
@ -176,16 +151,16 @@ impl<E> Future for RouteFuture<E> {
error!("response to CONNECT with nonempty body");
res = res.map(|_| Body::empty());
}
} else {
} else if *this.top_level {
set_allow_header(res.headers_mut(), this.allow_header);
// make sure to set content-length before removing the body
set_content_length(res.size_hint(), res.headers_mut());
}
let res = if *this.method == Method::HEAD {
res.map(|_| Body::empty())
} else {
res
};
if *this.method == Method::HEAD {
*res.body_mut() = Body::empty();
}
}
Poll::Ready(Ok(res))
}

View file

@ -1073,3 +1073,22 @@ async fn colon_in_route() {
async fn asterisk_in_route() {
_ = Router::<()>::new().route("/*foo", get(|| async move {}));
}
#[crate::test]
async fn middleware_adding_body() {
let app = Router::new()
.route("/", get(()))
.layer(MapResponseLayer::new(|mut res: Response| -> Response {
// If there is a content-length header, its value will be zero and Axum will avoid
// overwriting it. But this means our content-length doesnt match the length of the
// body, which leads to panics in Hyper. Thus we have to ensure that Axum doesnt add
// on content-length headers until after middleware has been run.
assert!(!res.headers().contains_key("content-length"));
*res.body_mut() = "".into();
res
}));
let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.text().await, "");
}