diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md
index e0379c2d..1bedc2ac 100644
--- a/axum/CHANGELOG.md
+++ b/axum/CHANGELOG.md
@@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `serde_json::Error` ([#1371])
 - **added**: `JsonRejection` now displays the path at which a deserialization
   error occurred too ([#1371])
+- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
 - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
   rejections, instead of `422 Unprocessable Entity` ([#1387])
 
 [#1371]: https://github.com/tokio-rs/axum/pull/1371
 [#1387]: https://github.com/tokio-rs/axum/pull/1387
+[#1389]: https://github.com/tokio-rs/axum/pull/1389
 
 # 0.6.0-rc.2 (10. September, 2022)
 
diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs
index 6e53af7b..f5775842 100644
--- a/axum/src/extract/content_length_limit.rs
+++ b/axum/src/extract/content_length_limit.rs
@@ -2,6 +2,7 @@ use super::{rejection::*, FromRequest};
 use async_trait::async_trait;
 use axum_core::{extract::FromRequestParts, response::IntoResponse};
 use http::{request::Parts, Method, Request};
+use http_body::Limited;
 use std::ops::Deref;
 
 /// Extractor that will reject requests with a body larger than some size.
@@ -36,32 +37,41 @@ use std::ops::Deref;
 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
 /// # };
 /// ```
-///
-/// This requires the request to have a `Content-Length` header.
-///
-/// If you want to limit the size of request bodies without requiring a `Content-Length` header,
-/// consider using [`tower_http::limit::RequestBodyLimitLayer`].
 #[derive(Debug, Clone)]
 pub struct ContentLengthLimit<T, const N: u64>(pub T);
 
 #[async_trait]
-impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
+impl<T, S, B, R, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
 where
-    T: FromRequest<S, B>,
-    T::Rejection: IntoResponse,
+    T: FromRequest<S, B, Rejection = R> + FromRequest<S, Limited<B>, Rejection = R>,
+    R: IntoResponse + Send,
     B: Send + 'static,
     S: Send + Sync,
 {
-    type Rejection = ContentLengthLimitRejection<T::Rejection>;
+    type Rejection = ContentLengthLimitRejection<R>;
 
     async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
         let (parts, body) = req.into_parts();
 
-        validate::<_, N>(&parts)?;
-        let req = Request::from_parts(parts, body);
-        let value = T::from_request(req, state)
-            .await
-            .map_err(ContentLengthLimitRejection::Inner)?;
+        let value = if let Some(err) = validate::<N>(&parts).err() {
+            match err {
+                RequestValidationError::LengthRequiredStream => {
+                    // `Limited` supports limiting streams, so use that instead since this is a
+                    // streaming request
+                    let body = Limited::new(body, N as usize);
+                    let req = Request::from_parts(parts, body);
+                    T::from_request(req, state)
+                        .await
+                        .map_err(ContentLengthLimitRejection::Inner)?
+                }
+                other => return Err(other.into()),
+            }
+        } else {
+            let req = Request::from_parts(parts, body);
+            T::from_request(req, state)
+                .await
+                .map_err(ContentLengthLimitRejection::Inner)?
+        };
 
         Ok(Self(value))
     }
@@ -77,7 +87,7 @@ where
     type Rejection = ContentLengthLimitRejection<T::Rejection>;
 
     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
-        validate::<_, N>(parts)?;
+        validate::<N>(parts)?;
 
         let value = T::from_request_parts(parts, state)
             .await
@@ -87,7 +97,7 @@ where
     }
 }
 
-fn validate<T, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<T>> {
+fn validate<const N: u64>(parts: &Parts) -> Result<(), RequestValidationError> {
     let content_length = parts
         .headers
         .get(http::header::CONTENT_LENGTH)
@@ -96,24 +106,20 @@ fn validate<T, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitReje
     match (content_length, &parts.method) {
         (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
             if content_length.is_some() {
-                return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
-                    ContentLengthNotAllowed,
-                ));
+                return Err(RequestValidationError::ContentLengthNotAllowed);
             } else if parts
                 .headers
                 .get(http::header::TRANSFER_ENCODING)
                 .map_or(false, |value| value.as_bytes() == b"chunked")
             {
-                return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
+                return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet);
             }
         }
         (Some(content_length), _) if content_length > N => {
-            return Err(ContentLengthLimitRejection::PayloadTooLarge(
-                PayloadTooLarge,
-            ));
+            return Err(RequestValidationError::PayloadTooLarge);
         }
         (None, _) => {
-            return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
+            return Err(RequestValidationError::LengthRequiredStream);
         }
         _ => {}
     }
@@ -129,6 +135,30 @@ impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
     }
 }
 
+/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the
+/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet`
+enum RequestValidationError {
+    PayloadTooLarge,
+    LengthRequiredStream,
+    LengthRequiredChunkedHeadOrGet,
+    ContentLengthNotAllowed,
+}
+
+impl<T> From<RequestValidationError> for ContentLengthLimitRejection<T> {
+    fn from(inner: RequestValidationError) -> Self {
+        match inner {
+            RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge),
+            RequestValidationError::LengthRequiredStream
+            | RequestValidationError::LengthRequiredChunkedHeadOrGet => {
+                Self::LengthRequired(LengthRequired)
+            }
+            RequestValidationError::ContentLengthNotAllowed => {
+                Self::ContentLengthNotAllowed(ContentLengthNotAllowed)
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -180,14 +210,25 @@
             .await;
         assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
 
+        let chunk = repeat(0_u8).take(LIMIT as usize).collect::<Bytes>();
         let res = client
            .post("/")
            .body(reqwest::Body::wrap_stream(futures_util::stream::iter(
-                vec![Ok::<_, std::io::Error>(Bytes::new())],
+                vec![Ok::<_, std::io::Error>(chunk)],
            )))
            .send()
            .await;
-        assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
+        assert_eq!(res.status(), StatusCode::OK);
+
+        let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::<Bytes>();
+        let res = client
+            .post("/")
+            .body(reqwest::Body::wrap_stream(futures_util::stream::iter(
+                vec![Ok::<_, std::io::Error>(chunk)],
+            )))
+            .send()
+            .await;
+        assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
     }
 
     #[tokio::test]