diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 94b26094..091ff855 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,8 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **added:** Automatically handle `http_body::LengthLimitError` in `FailedToBufferBody` and map + such errors to `413 Payload Too Large` ([#1048]) - **fixed:** Use `impl IntoResponse` less in docs ([#1049]) +[#1048]: https://github.com/tokio-rs/axum/pull/1048 [#1049]: https://github.com/tokio-rs/axum/pull/1049 # 0.2.4 (02. May, 2022) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index 9ce2a642..2b817302 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -15,7 +15,7 @@ async-trait = "0.1" bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2.7" -http-body = "0.4" +http-body = "0.4.5" mime = "0.3.16" [dev-dependencies] diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index fc81eda7..e6f53b82 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -1,6 +1,9 @@ //! Rejection response types. -use crate::response::{IntoResponse, Response}; +use crate::{ + response::{IntoResponse, Response}, + BoxError, +}; use http::StatusCode; use std::fmt; @@ -28,12 +31,44 @@ impl fmt::Display for BodyAlreadyExtracted { impl std::error::Error for BodyAlreadyExtracted {} +composite_rejection! { + /// Rejection type for extractors that buffer the request body. Used if the + /// request body cannot be buffered due to an error. + pub enum FailedToBufferBody { + LengthLimitError, + UnknownBodyError, + } +} + +impl FailedToBufferBody { + pub(crate) fn from_err<E>(err: E) -> Self + where + E: Into<BoxError>, + { + match err.into().downcast::<http_body::LengthLimitError>() { + Ok(err) => Self::LengthLimitError(LengthLimitError::from_err(err)), + Err(err) => Self::UnknownBodyError(UnknownBodyError::from_err(err)), + } + } +} + +define_rejection! { + #[status = PAYLOAD_TOO_LARGE] + #[body = "Failed to buffer the request body"] + /// Encountered some other error when buffering the body. + /// + /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or + /// otherwise wrapping request bodies in [`http_body::Limited`]. + /// + /// [`tower_http::limit::RequestBodyLimitLayer`]: https://docs.rs/tower-http/0.3/tower_http/limit/struct.RequestBodyLimitLayer.html + pub struct LengthLimitError(Error); +} + define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to buffer the request body"] - /// Rejection type for extractors that buffer the request body. Used if the - /// request body cannot be buffered due to an error. - pub struct FailedToBufferBody(Error); + /// Encountered an unknown error when buffering the body. + pub struct UnknownBodyError(Error); } define_rejection! { diff --git a/axum-core/src/macros.rs b/axum-core/src/macros.rs index ee0bfd7b..9ea64cdc 100644 --- a/axum-core/src/macros.rs +++ b/axum-core/src/macros.rs @@ -10,6 +10,7 @@ macro_rules! define_rejection { pub struct $name(pub(crate) crate::Error); impl $name { + #[allow(dead_code)] pub(crate) fn from_err<E>(err: E) -> Self where E: Into<crate::BoxError>, diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 4c75dd85..7d4f0971 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -81,7 +81,7 @@ features = [ ] [dev-dependencies.tower-http] -version = "0.3.0" +version = "0.3.4" features = ["full"] [package.metadata.docs.rs] diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 6412324b..62148a8e 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -29,6 +29,9 @@ use std::ops::Deref; /// ``` /// /// This requires the request to have a `Content-Length` header. +/// +/// If you want to limit the size of request bodies without requiring a `Content-Length` header, +/// consider using [`tower_http::limit::RequestBodyLimitLayer`]. #[derive(Debug, Clone)] pub struct ContentLengthLimit<T, const N: u64>(pub T); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 768ffe02..b253bc52 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -8,7 +8,7 @@ use crate::{ test_helpers::*, BoxError, Json, Router, }; -use http::{Method, Request, Response, StatusCode, Uri}; +use http::{header::CONTENT_LENGTH, HeaderMap, Method, Request, Response, StatusCode, Uri}; use hyper::Body; use serde::Deserialize; use serde_json::{json, Value}; @@ -20,7 +20,7 @@ use std::{ time::Duration, }; use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder, ServiceExt}; -use tower_http::auth::RequireAuthorizationLayer; +use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer}; use tower_service::Service; mod fallback; @@ -699,3 +699,57 @@ async fn routes_must_start_with_slash() { let app = Router::new().route(":foo", get(|| async {})); TestClient::new(app); } + +#[tokio::test] +async fn limited_body_with_content_length() { + const LIMIT: usize = 3; + + let app = Router::new() + .route( + "/", + post(|headers: HeaderMap, _body: Bytes| async move { + assert!(headers.get(CONTENT_LENGTH).is_some()); + }), + ) + .layer(RequestBodyLimitLayer::new(LIMIT)); + + let client = TestClient::new(app); + + let res = client.post("/").body("a".repeat(LIMIT)).send().await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await; + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); +} + +#[tokio::test] +async fn limited_body_with_streaming_body() { + const LIMIT: usize = 3; + + let app = Router::new() + .route( + "/", + post(|headers: HeaderMap, _body: Bytes| async move { + assert!(headers.get(CONTENT_LENGTH).is_none()); + }), + ) + .layer(RequestBodyLimitLayer::new(LIMIT)); + + let client = TestClient::new(app); + + let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT))]); + let res = client + .post("/") + .body(Body::wrap_stream(stream)) + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT * 2))]); + let res = client + .post("/") + .body(Body::wrap_stream(stream)) + .send() + .await; + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); +} diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 27f4be4a..34791b45 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -118,6 +118,7 @@ impl RequestBuilder { } } +#[derive(Debug)] pub(crate) struct TestResponse { response: reqwest::Response, }