mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Support streaming/chunked requests in ContentLengthLimit
(#1389)
* Support streaming/chunked requests in `ContentLengthLimit` * changelog
This commit is contained in:
parent
015de21a52
commit
c81549d95b
2 changed files with 69 additions and 26 deletions
|
@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
`serde_json::Error` ([#1371])
|
||||
- **added**: `JsonRejection` now displays the path at which a deserialization
|
||||
error occurred too ([#1371])
|
||||
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
|
||||
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
|
||||
rejections, instead of `422 Unprocessable Entity` ([#1387])
|
||||
|
||||
[#1371]: https://github.com/tokio-rs/axum/pull/1371
|
||||
[#1387]: https://github.com/tokio-rs/axum/pull/1387
|
||||
[#1389]: https://github.com/tokio-rs/axum/pull/1389
|
||||
|
||||
# 0.6.0-rc.2 (10. September, 2022)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::{rejection::*, FromRequest};
|
|||
use async_trait::async_trait;
|
||||
use axum_core::{extract::FromRequestParts, response::IntoResponse};
|
||||
use http::{request::Parts, Method, Request};
|
||||
use http_body::Limited;
|
||||
use std::ops::Deref;
|
||||
|
||||
/// Extractor that will reject requests with a body larger than some size.
|
||||
|
@ -36,32 +37,41 @@ use std::ops::Deref;
|
|||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
///
|
||||
/// This requires the request to have a `Content-Length` header.
|
||||
///
|
||||
/// If you want to limit the size of request bodies without requiring a `Content-Length` header,
|
||||
/// consider using [`tower_http::limit::RequestBodyLimitLayer`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContentLengthLimit<T, const N: u64>(pub T);
|
||||
|
||||
#[async_trait]
|
||||
impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
|
||||
impl<T, S, B, R, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
|
||||
where
|
||||
T: FromRequest<S, B>,
|
||||
T::Rejection: IntoResponse,
|
||||
T: FromRequest<S, B, Rejection = R> + FromRequest<S, Limited<B>, Rejection = R>,
|
||||
R: IntoResponse + Send,
|
||||
B: Send + 'static,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = ContentLengthLimitRejection<T::Rejection>;
|
||||
type Rejection = ContentLengthLimitRejection<R>;
|
||||
|
||||
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let (parts, body) = req.into_parts();
|
||||
validate::<_, N>(&parts)?;
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
let value = T::from_request(req, state)
|
||||
.await
|
||||
.map_err(ContentLengthLimitRejection::Inner)?;
|
||||
let value = if let Some(err) = validate::<N>(&parts).err() {
|
||||
match err {
|
||||
RequestValidationError::LengthRequiredStream => {
|
||||
// `Limited` supports limiting streams, so use that instead since this is a
|
||||
// streaming request
|
||||
let body = Limited::new(body, N as usize);
|
||||
let req = Request::from_parts(parts, body);
|
||||
T::from_request(req, state)
|
||||
.await
|
||||
.map_err(ContentLengthLimitRejection::Inner)?
|
||||
}
|
||||
other => return Err(other.into()),
|
||||
}
|
||||
} else {
|
||||
let req = Request::from_parts(parts, body);
|
||||
T::from_request(req, state)
|
||||
.await
|
||||
.map_err(ContentLengthLimitRejection::Inner)?
|
||||
};
|
||||
|
||||
Ok(Self(value))
|
||||
}
|
||||
|
@ -77,7 +87,7 @@ where
|
|||
type Rejection = ContentLengthLimitRejection<T::Rejection>;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
validate::<_, N>(parts)?;
|
||||
validate::<N>(parts)?;
|
||||
|
||||
let value = T::from_request_parts(parts, state)
|
||||
.await
|
||||
|
@ -87,7 +97,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<E>> {
|
||||
fn validate<const N: u64>(parts: &Parts) -> Result<(), RequestValidationError> {
|
||||
let content_length = parts
|
||||
.headers
|
||||
.get(http::header::CONTENT_LENGTH)
|
||||
|
@ -96,24 +106,20 @@ fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitReje
|
|||
match (content_length, &parts.method) {
|
||||
(content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
|
||||
if content_length.is_some() {
|
||||
return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
|
||||
ContentLengthNotAllowed,
|
||||
));
|
||||
return Err(RequestValidationError::ContentLengthNotAllowed);
|
||||
} else if parts
|
||||
.headers
|
||||
.get(http::header::TRANSFER_ENCODING)
|
||||
.map_or(false, |value| value.as_bytes() == b"chunked")
|
||||
{
|
||||
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
|
||||
return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet);
|
||||
}
|
||||
}
|
||||
(Some(content_length), _) if content_length > N => {
|
||||
return Err(ContentLengthLimitRejection::PayloadTooLarge(
|
||||
PayloadTooLarge,
|
||||
));
|
||||
return Err(RequestValidationError::PayloadTooLarge);
|
||||
}
|
||||
(None, _) => {
|
||||
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
|
||||
return Err(RequestValidationError::LengthRequiredStream);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -129,6 +135,30 @@ impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the
|
||||
/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet`
|
||||
enum RequestValidationError {
|
||||
PayloadTooLarge,
|
||||
LengthRequiredStream,
|
||||
LengthRequiredChunkedHeadOrGet,
|
||||
ContentLengthNotAllowed,
|
||||
}
|
||||
|
||||
impl<T> From<RequestValidationError> for ContentLengthLimitRejection<T> {
|
||||
fn from(inner: RequestValidationError) -> Self {
|
||||
match inner {
|
||||
RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge),
|
||||
RequestValidationError::LengthRequiredStream
|
||||
| RequestValidationError::LengthRequiredChunkedHeadOrGet => {
|
||||
Self::LengthRequired(LengthRequired)
|
||||
}
|
||||
RequestValidationError::ContentLengthNotAllowed => {
|
||||
Self::ContentLengthNotAllowed(ContentLengthNotAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -180,14 +210,25 @@ mod tests {
|
|||
.await;
|
||||
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
|
||||
let chunk = repeat(0_u8).take(LIMIT as usize).collect::<Bytes>();
|
||||
let res = client
|
||||
.post("/")
|
||||
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
|
||||
vec![Ok::<_, std::io::Error>(Bytes::new())],
|
||||
vec![Ok::<_, std::io::Error>(chunk)],
|
||||
)))
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::<Bytes>();
|
||||
let res = client
|
||||
.post("/")
|
||||
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
|
||||
vec![Ok::<_, std::io::Error>(chunk)],
|
||||
)))
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
Loading…
Reference in a new issue