mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 22:56:46 +01:00
fix: Avoid setting content-length before middleware (#2897)
This commit is contained in:
parent
822db3b1af
commit
5512b5b91f
3 changed files with 44 additions and 47 deletions
|
@ -11,7 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)`
|
||||
- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]`
|
||||
and `MethodRouter::connect[_service]` ([#2961])
|
||||
- **fixed:** Avoid setting `content-length` before middleware ([#2897]).
|
||||
This allows middleware to add bodies to requests without needing to manually set `content-length`
|
||||
|
||||
[#2897]: https://github.com/tokio-rs/axum/pull/2897
|
||||
[#2984]: https://github.com/tokio-rs/axum/pull/2984
|
||||
[#2961]: https://github.com/tokio-rs/axum/pull/2961
|
||||
|
||||
|
|
|
@ -15,10 +15,10 @@ use std::{
|
|||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tower::{
|
||||
util::{MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
|
||||
util::{MapErrLayer, MapResponseLayer, Oneshot},
|
||||
ServiceExt,
|
||||
};
|
||||
use tower_layer::Layer;
|
||||
|
@ -44,7 +44,7 @@ impl<E> Route<E> {
|
|||
|
||||
pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture<E> {
|
||||
let method = req.method().clone();
|
||||
RouteFuture::from_future(method, self.0.clone().oneshot(req))
|
||||
RouteFuture::new(method, self.0.clone().oneshot(req))
|
||||
}
|
||||
|
||||
pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
|
||||
|
@ -57,7 +57,6 @@ impl<E> Route<E> {
|
|||
NewError: 'static,
|
||||
{
|
||||
let layer = (
|
||||
MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
|
||||
MapErrLayer::new(Into::into),
|
||||
MapResponseLayer::new(IntoResponse::into_response),
|
||||
layer,
|
||||
|
@ -96,7 +95,7 @@ where
|
|||
|
||||
#[inline]
|
||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||
self.oneshot_inner(req.map(Body::new))
|
||||
self.oneshot_inner(req.map(Body::new)).not_top_level()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,37 +103,20 @@ pin_project! {
|
|||
/// Response future for [`Route`].
|
||||
pub struct RouteFuture<E> {
|
||||
#[pin]
|
||||
kind: RouteFutureKind<E>,
|
||||
inner: Oneshot<BoxCloneService<Request, Response, E>, Request>,
|
||||
method: Method,
|
||||
allow_header: Option<Bytes>,
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
#[project = RouteFutureKindProj]
|
||||
enum RouteFutureKind<E> {
|
||||
Future {
|
||||
#[pin]
|
||||
future: Oneshot<
|
||||
BoxCloneService<Request, Response, E>,
|
||||
Request,
|
||||
>,
|
||||
},
|
||||
Response {
|
||||
response: Option<Response>,
|
||||
}
|
||||
top_level: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> RouteFuture<E> {
|
||||
pub(crate) fn from_future(
|
||||
method: Method,
|
||||
future: Oneshot<BoxCloneService<Request, Response, E>, Request>,
|
||||
) -> Self {
|
||||
fn new(method: Method, inner: Oneshot<BoxCloneService<Request, Response, E>, Request>) -> Self {
|
||||
Self {
|
||||
kind: RouteFutureKind::Future { future },
|
||||
inner,
|
||||
method,
|
||||
allow_header: None,
|
||||
top_level: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -142,6 +124,11 @@ impl<E> RouteFuture<E> {
|
|||
self.allow_header = Some(allow_header);
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn not_top_level(mut self) -> Self {
|
||||
self.top_level = false;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Future for RouteFuture<E> {
|
||||
|
@ -150,19 +137,7 @@ impl<E> Future for RouteFuture<E> {
|
|||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
let mut res = match this.kind.project() {
|
||||
RouteFutureKindProj::Future { future } => match future.poll(cx) {
|
||||
Poll::Ready(Ok(res)) => res,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
},
|
||||
RouteFutureKindProj::Response { response } => {
|
||||
response.take().expect("future polled after completion")
|
||||
}
|
||||
};
|
||||
|
||||
set_allow_header(res.headers_mut(), this.allow_header);
|
||||
let mut res = ready!(this.inner.poll(cx))?;
|
||||
|
||||
if *this.method == Method::CONNECT && res.status().is_success() {
|
||||
// From https://httpwg.org/specs/rfc9110.html#CONNECT:
|
||||
|
@ -176,16 +151,16 @@ impl<E> Future for RouteFuture<E> {
|
|||
error!("response to CONNECT with nonempty body");
|
||||
res = res.map(|_| Body::empty());
|
||||
}
|
||||
} else {
|
||||
} else if *this.top_level {
|
||||
set_allow_header(res.headers_mut(), this.allow_header);
|
||||
|
||||
// make sure to set content-length before removing the body
|
||||
set_content_length(res.size_hint(), res.headers_mut());
|
||||
}
|
||||
|
||||
let res = if *this.method == Method::HEAD {
|
||||
res.map(|_| Body::empty())
|
||||
} else {
|
||||
res
|
||||
};
|
||||
if *this.method == Method::HEAD {
|
||||
*res.body_mut() = Body::empty();
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
|
|
|
@ -1073,3 +1073,22 @@ async fn colon_in_route() {
|
|||
async fn asterisk_in_route() {
|
||||
_ = Router::<()>::new().route("/*foo", get(|| async move {}));
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn middleware_adding_body() {
|
||||
let app = Router::new()
|
||||
.route("/", get(()))
|
||||
.layer(MapResponseLayer::new(|mut res: Response| -> Response {
|
||||
// If there is a content-length header, its value will be zero and Axum will avoid
|
||||
// overwriting it. But this means our content-length doesn’t match the length of the
|
||||
// body, which leads to panics in Hyper. Thus we have to ensure that Axum doesn’t add
|
||||
// on content-length headers until after middleware has been run.
|
||||
assert!(!res.headers().contains_key("content-length"));
|
||||
*res.body_mut() = "…".into();
|
||||
res
|
||||
}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
let res = client.get("/").await;
|
||||
assert_eq!(res.text().await, "…");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue