diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 71cccf5d..fea1a3ed 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) + +[#1397]: https://github.com/tokio-rs/axum/pull/1397 # 0.3.0-rc.2 (10. September, 2022) diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs index fcf44b79..21616287 100644 --- a/axum-core/src/extract/default_body_limit.rs +++ b/axum-core/src/extract/default_body_limit.rs @@ -16,8 +16,15 @@ use tower_layer::Layer; /// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html #[derive(Debug, Clone)] -#[non_exhaustive] -pub struct DefaultBodyLimit; +pub struct DefaultBodyLimit { + kind: DefaultBodyLimitKind, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) enum DefaultBodyLimitKind { + Disable, + Limit(usize), +} impl DefaultBodyLimit { /// Disable the default request body limit. @@ -53,7 +60,42 @@ impl DefaultBodyLimit { /// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html pub fn disable() -> Self { - Self + Self { + kind: DefaultBodyLimitKind::Disable, + } + } + + /// Set the default request body limit. + /// + /// By default the limit of request body sizes that [`Bytes::from_request`] (and other + /// extractors built on top of it such as `String`, [`Json`], and [`Form`]) is 2MB. This method + /// can be used to change that limit. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// Router, + /// routing::get, + /// body::{Bytes, Body}, + /// extract::DefaultBodyLimit, + /// }; + /// use tower_http::limit::RequestBodyLimitLayer; + /// use http_body::Limited; + /// + /// let app: Router<_, Limited<Body>> = Router::new() + /// .route("/", get(|body: Bytes| async {})) + /// // Replace the default of 2MB with 1024 bytes. + /// .layer(DefaultBodyLimit::max(1024)); + /// ``` + /// + /// [`Bytes::from_request`]: bytes::Bytes + /// [`Json`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Json.html + /// [`Form`]: https://docs.rs/axum/0.6.0-rc.2/axum/struct.Form.html + pub fn max(limit: usize) -> Self { + Self { + kind: DefaultBodyLimitKind::Limit(limit), + } } } @@ -61,15 +103,15 @@ impl<S> Layer<S> for DefaultBodyLimit { type Service = DefaultBodyLimitService<S>; fn layer(&self, inner: S) -> Self::Service { - DefaultBodyLimitService { inner } + DefaultBodyLimitService { + inner, + kind: self.kind, + } } } -#[derive(Copy, Clone, Debug)] -pub(crate) struct DefaultBodyLimitDisabled; - mod private { - use super::DefaultBodyLimitDisabled; + use super::DefaultBodyLimitKind; use http::Request; use std::task::Context; use tower_service::Service; @@ -77,6 +119,7 @@ mod private { #[derive(Debug, Clone, Copy)] pub struct DefaultBodyLimitService<S> { pub(super) inner: S, + pub(super) kind: DefaultBodyLimitKind, } impl<B, S> Service<Request<B>> for DefaultBodyLimitService<S> @@ -94,7 +137,7 @@ mod private { #[inline] fn call(&mut self, mut req: Request<B>) -> Self::Future { - req.extensions_mut().insert(DefaultBodyLimitDisabled); + req.extensions_mut().insert(self.kind); self.inner.call(req) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 60ef1b79..fc151a2d 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,5 +1,5 @@ use super::{ - default_body_limit::DefaultBodyLimitDisabled, rejection::*, FromRequest, FromRequestParts, + default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts, }; use crate::BoxError; use async_trait::async_trait; @@ -88,15 +88,23 @@ where // `axum/src/docs/extract.md` if this changes const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb - let bytes = if req.extensions().get::<DefaultBodyLimitDisabled>().is_some() { - crate::body::to_bytes(req.into_body()) + let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied(); + let bytes = match limit_kind { + Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body()) .await - .map_err(FailedToBufferBody::from_err)? - } else { - let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT); - crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? + .map_err(FailedToBufferBody::from_err)?, + Some(DefaultBodyLimitKind::Limit(limit)) => { + let body = http_body::Limited::new(req.into_body(), limit); + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + } + None => { + let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT); + crate::body::to_bytes(body) + .await + .map_err(FailedToBufferBody::from_err)? + } }; Ok(bytes) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1bedc2ac..94052653 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -15,10 +15,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389]) - **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString` rejections, instead of `422 Unprocessable Entity` ([#1387]) +- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1387]: https://github.com/tokio-rs/axum/pull/1387 [#1389]: https://github.com/tokio-rs/axum/pull/1389 +[#1397]: https://github.com/tokio-rs/axum/pull/1397 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 08efd28f..4d4c58cd 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -671,6 +671,31 @@ async fn limited_body_with_content_length() { assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } +#[tokio::test] +async fn changing_the_default_limit() { + let new_limit = 2; + + let app = Router::new() + .route("/", post(|_: Bytes| async {})) + .layer(DefaultBodyLimit::max(new_limit)); + + let client = TestClient::new(app); + + let res = client + .post("/") + .body(Body::from("a".repeat(new_limit))) + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post("/") + .body(Body::from("a".repeat(new_limit + 1))) + .send() + .await; + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); +} + #[tokio::test] async fn limited_body_with_streaming_body() { const LIMIT: usize = 3;