Support streaming/chunked requests in `ContentLengthLimit` (#1389)

* Support streaming/chunked requests in `ContentLengthLimit`
* changelog

parent 015de21a52
commit c81549d95b
2 changed files with 69 additions and 26 deletions
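For orientation, here is a minimal sketch (not part of the commit) of how the `ContentLengthLimit` extractor is used at the time of this change; the handler name and route are made up. The point of the commit is that the limit now also covers chunked/streaming requests that carry no `Content-Length` header, instead of rejecting them outright with `411 Length Required`:

```rust
use axum::{extract::ContentLengthLimit, routing::post, Router};

// Hypothetical handler: bodies larger than 1024 bytes are rejected by the
// extractor, so `body` is guaranteed to be at most 1024 bytes long.
async fn upload(ContentLengthLimit(body): ContentLengthLimit<String, 1024>) -> String {
    format!("received {} bytes", body.len())
}

fn app() -> Router {
    Router::new().route("/", post(upload))
}
```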
@@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `serde_json::Error` ([#1371])
 - **added**: `JsonRejection` now displays the path at which a deserialization
   error occurred too ([#1371])
+- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
 - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
   rejections, instead of `422 Unprocessable Entity` ([#1387])
 
 [#1371]: https://github.com/tokio-rs/axum/pull/1371
 [#1387]: https://github.com/tokio-rs/axum/pull/1387
+[#1389]: https://github.com/tokio-rs/axum/pull/1389
 
 # 0.6.0-rc.2 (10. September, 2022)
 
@@ -2,6 +2,7 @@ use super::{rejection::*, FromRequest};
 use async_trait::async_trait;
 use axum_core::{extract::FromRequestParts, response::IntoResponse};
 use http::{request::Parts, Method, Request};
+use http_body::Limited;
 use std::ops::Deref;
 
 /// Extractor that will reject requests with a body larger than some size.
@@ -36,32 +37,41 @@ use std::ops::Deref;
 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
 /// # };
 /// ```
-///
-/// This requires the request to have a `Content-Length` header.
-///
-/// If you want to limit the size of request bodies without requiring a `Content-Length` header,
-/// consider using [`tower_http::limit::RequestBodyLimitLayer`].
 #[derive(Debug, Clone)]
 pub struct ContentLengthLimit<T, const N: u64>(pub T);
 
 #[async_trait]
-impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
+impl<T, S, B, R, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
 where
-    T: FromRequest<S, B>,
-    T::Rejection: IntoResponse,
+    T: FromRequest<S, B, Rejection = R> + FromRequest<S, Limited<B>, Rejection = R>,
+    R: IntoResponse + Send,
     B: Send + 'static,
     S: Send + Sync,
 {
-    type Rejection = ContentLengthLimitRejection<T::Rejection>;
+    type Rejection = ContentLengthLimitRejection<R>;
 
     async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
         let (parts, body) = req.into_parts();
-        validate::<_, N>(&parts)?;
 
-        let req = Request::from_parts(parts, body);
-        let value = T::from_request(req, state)
-            .await
-            .map_err(ContentLengthLimitRejection::Inner)?;
+        let value = if let Some(err) = validate::<N>(&parts).err() {
+            match err {
+                RequestValidationError::LengthRequiredStream => {
+                    // `Limited` supports limiting streams, so use that instead since this is a
+                    // streaming request
+                    let body = Limited::new(body, N as usize);
+                    let req = Request::from_parts(parts, body);
+                    T::from_request(req, state)
+                        .await
+                        .map_err(ContentLengthLimitRejection::Inner)?
+                }
+                other => return Err(other.into()),
+            }
+        } else {
+            let req = Request::from_parts(parts, body);
+            T::from_request(req, state)
+                .await
+                .map_err(ContentLengthLimitRejection::Inner)?
+        };
 
         Ok(Self(value))
     }
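The new code path leans on `http_body::Limited`, which wraps another body and fails once more than the configured number of bytes has been read; that is what allows enforcing the limit for bodies whose length is not declared up front. A standalone sketch of that behavior, using `hyper::Body` and `hyper::body::to_bytes` purely for illustration (not code from this commit):

```rust
use http_body::Limited;
use hyper::Body;

#[tokio::main]
async fn main() {
    // An 11-byte in-memory body standing in for a streaming request body.
    let body = Body::from("hello world");

    // Allow at most 5 bytes to be read through the wrapper.
    let limited = Limited::new(body, 5);

    // Driving the wrapped body to completion trips the limit, so we get an
    // error (a length-limit error) back instead of the collected bytes.
    let result = hyper::body::to_bytes(limited).await;
    assert!(result.is_err());
}
```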
@@ -77,7 +87,7 @@ where
     type Rejection = ContentLengthLimitRejection<T::Rejection>;
 
     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
-        validate::<_, N>(parts)?;
+        validate::<N>(parts)?;
 
         let value = T::from_request_parts(parts, state)
             .await
@@ -87,7 +97,7 @@ where
     }
 }
 
-fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<E>> {
+fn validate<const N: u64>(parts: &Parts) -> Result<(), RequestValidationError> {
     let content_length = parts
         .headers
         .get(http::header::CONTENT_LENGTH)
@@ -96,24 +106,20 @@ fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitReje
     match (content_length, &parts.method) {
         (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
             if content_length.is_some() {
-                return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
-                    ContentLengthNotAllowed,
-                ));
+                return Err(RequestValidationError::ContentLengthNotAllowed);
             } else if parts
                 .headers
                 .get(http::header::TRANSFER_ENCODING)
                 .map_or(false, |value| value.as_bytes() == b"chunked")
             {
-                return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
+                return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet);
             }
         }
         (Some(content_length), _) if content_length > N => {
-            return Err(ContentLengthLimitRejection::PayloadTooLarge(
-                PayloadTooLarge,
-            ));
+            return Err(RequestValidationError::PayloadTooLarge);
         }
         (None, _) => {
-            return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
+            return Err(RequestValidationError::LengthRequiredStream);
         }
         _ => {}
     }
@@ -129,6 +135,30 @@ impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
     }
 }
 
+/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the
+/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet`
+enum RequestValidationError {
+    PayloadTooLarge,
+    LengthRequiredStream,
+    LengthRequiredChunkedHeadOrGet,
+    ContentLengthNotAllowed,
+}
+
+impl<T> From<RequestValidationError> for ContentLengthLimitRejection<T> {
+    fn from(inner: RequestValidationError) -> Self {
+        match inner {
+            RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge),
+            RequestValidationError::LengthRequiredStream
+            | RequestValidationError::LengthRequiredChunkedHeadOrGet => {
+                Self::LengthRequired(LengthRequired)
+            }
+            RequestValidationError::ContentLengthNotAllowed => {
+                Self::ContentLengthNotAllowed(ContentLengthNotAllowed)
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -180,14 +210,25 @@ mod tests {
             .await;
         assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
 
+        let chunk = repeat(0_u8).take(LIMIT as usize).collect::<Bytes>();
         let res = client
             .post("/")
             .body(reqwest::Body::wrap_stream(futures_util::stream::iter(
-                vec![Ok::<_, std::io::Error>(Bytes::new())],
+                vec![Ok::<_, std::io::Error>(chunk)],
             )))
             .send()
             .await;
-        assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
+        assert_eq!(res.status(), StatusCode::OK);
+
+        let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::<Bytes>();
+        let res = client
+            .post("/")
+            .body(reqwest::Body::wrap_stream(futures_util::stream::iter(
+                vec![Ok::<_, std::io::Error>(chunk)],
+            )))
+            .send()
+            .await;
+        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
     }
 
     #[tokio::test]
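For reference, the updated test sends its body via `reqwest::Body::wrap_stream`, which produces a chunked request without a `Content-Length` header; that is exactly the request shape this commit starts accepting. A client-side sketch of the same idea (the address is a placeholder, and `wrap_stream` is behind reqwest's `stream` feature):

```rust
use bytes::Bytes;
use futures_util::stream;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // A streamed body is sent with `Transfer-Encoding: chunked`, so the server
    // never sees a `Content-Length` header for it.
    let chunks = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from_static(b"hello"))]);

    let res = reqwest::Client::new()
        .post("http://localhost:3000/") // placeholder address
        .body(reqwest::Body::wrap_stream(chunks))
        .send()
        .await?;

    // With the change above, a streamed body within the limit now gets the
    // handler's response (e.g. 200 OK) instead of 411 Length Required.
    println!("{}", res.status());
    Ok(())
}
```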