Support streaming/chunked requests in ContentLengthLimit (#1389)

* Support streaming/chunked requests in `ContentLengthLimit`

* changelog
This commit is contained in:
David Pedersen 2022-09-18 22:21:38 +02:00 committed by GitHub
parent 015de21a52
commit c81549d95b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 26 deletions

View file

@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`serde_json::Error` ([#1371])
- **added**: `JsonRejection` now displays the path at which a deserialization
error occurred too ([#1371])
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
rejections, instead of `422 Unprocessable Entity` ([#1387])
[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387
[#1389]: https://github.com/tokio-rs/axum/pull/1389
# 0.6.0-rc.2 (10. September, 2022)

View file

@ -2,6 +2,7 @@ use super::{rejection::*, FromRequest};
use async_trait::async_trait;
use axum_core::{extract::FromRequestParts, response::IntoResponse};
use http::{request::Parts, Method, Request};
use http_body::Limited;
use std::ops::Deref;
/// Extractor that will reject requests with a body larger than some size.
@ -36,32 +37,41 @@ use std::ops::Deref;
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// This requires the request to have a `Content-Length` header.
///
/// If you want to limit the size of request bodies without requiring a `Content-Length` header,
/// consider using [`tower_http::limit::RequestBodyLimitLayer`].
#[derive(Debug, Clone)]
pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait]
impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
impl<T, S, B, R, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where
T: FromRequest<S, B>,
T::Rejection: IntoResponse,
T: FromRequest<S, B, Rejection = R> + FromRequest<S, Limited<B>, Rejection = R>,
R: IntoResponse + Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
type Rejection = ContentLengthLimitRejection<R>;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
validate::<_, N>(&parts)?;
let req = Request::from_parts(parts, body);
let value = T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
let value = if let Some(err) = validate::<N>(&parts).err() {
match err {
RequestValidationError::LengthRequiredStream => {
// `Limited` supports limiting streams, so use that instead since this is a
// streaming request
let body = Limited::new(body, N as usize);
let req = Request::from_parts(parts, body);
T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?
}
other => return Err(other.into()),
}
} else {
let req = Request::from_parts(parts, body);
T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?
};
Ok(Self(value))
}
@ -77,7 +87,7 @@ where
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
validate::<_, N>(parts)?;
validate::<N>(parts)?;
let value = T::from_request_parts(parts, state)
.await
@ -87,7 +97,7 @@ where
}
}
fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<E>> {
fn validate<const N: u64>(parts: &Parts) -> Result<(), RequestValidationError> {
let content_length = parts
.headers
.get(http::header::CONTENT_LENGTH)
@ -96,24 +106,20 @@ fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitReje
match (content_length, &parts.method) {
(content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
if content_length.is_some() {
return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
ContentLengthNotAllowed,
));
return Err(RequestValidationError::ContentLengthNotAllowed);
} else if parts
.headers
.get(http::header::TRANSFER_ENCODING)
.map_or(false, |value| value.as_bytes() == b"chunked")
{
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet);
}
}
(Some(content_length), _) if content_length > N => {
return Err(ContentLengthLimitRejection::PayloadTooLarge(
PayloadTooLarge,
));
return Err(RequestValidationError::PayloadTooLarge);
}
(None, _) => {
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
return Err(RequestValidationError::LengthRequiredStream);
}
_ => {}
}
@ -129,6 +135,30 @@ impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
}
}
/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the
/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet`
enum RequestValidationError {
PayloadTooLarge,
LengthRequiredStream,
LengthRequiredChunkedHeadOrGet,
ContentLengthNotAllowed,
}
impl<T> From<RequestValidationError> for ContentLengthLimitRejection<T> {
fn from(inner: RequestValidationError) -> Self {
match inner {
RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge),
RequestValidationError::LengthRequiredStream
| RequestValidationError::LengthRequiredChunkedHeadOrGet => {
Self::LengthRequired(LengthRequired)
}
RequestValidationError::ContentLengthNotAllowed => {
Self::ContentLengthNotAllowed(ContentLengthNotAllowed)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -180,14 +210,25 @@ mod tests {
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
let chunk = repeat(0_u8).take(LIMIT as usize).collect::<Bytes>();
let res = client
.post("/")
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
vec![Ok::<_, std::io::Error>(Bytes::new())],
vec![Ok::<_, std::io::Error>(chunk)],
)))
.send()
.await;
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
assert_eq!(res.status(), StatusCode::OK);
let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::<Bytes>();
let res = client
.post("/")
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
vec![Ok::<_, std::io::Error>(chunk)],
)))
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
#[tokio::test]