diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9e7832d3..ccd49bb9 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **fixed:** Correctly handle `GET`, `HEAD`, and `OPTIONS` requests in `ContentLengthLimit`. + Request with these methods are now accepted if they _do not_ have a `Content-Length` header, and + the request body will not be checked. If they do have a `Content-Length` header they'll be + rejected. This allows `ContentLengthLimit` to be used as middleware around several routes, + including `GET` routes ([#989]) + +[#989]: https://github.com/tokio-rs/axum/pull/989 # 0.5.4 (26. April, 2022) diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index afdb7247..6412324b 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -1,10 +1,14 @@ use super::{rejection::*, FromRequest, RequestParts}; use async_trait::async_trait; use axum_core::response::IntoResponse; +use http::Method; use std::ops::Deref; /// Extractor that will reject requests with a body larger than some size. /// +/// `GET`, `HEAD`, and `OPTIONS` requests are rejected if they have a `Content-Length` header, +/// otherwise they're accepted without the body being checked. +/// /// # Example /// /// ```rust,no_run @@ -38,20 +42,35 @@ where type Rejection = ContentLengthLimitRejection; async fn from_request(req: &mut RequestParts) -> Result { - let content_length = req.headers().get(http::header::CONTENT_LENGTH); + let content_length = req + .headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()?.parse::().ok()); - let content_length = - content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); - - if let Some(length) = content_length { - if length > N { + match (content_length, req.method()) { + (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => { + if content_length.is_some() { + return Err(ContentLengthLimitRejection::ContentLengthNotAllowed( + ContentLengthNotAllowed, + )); + } else if req + .headers() + .get(http::header::TRANSFER_ENCODING) + .map_or(false, |value| value.as_bytes() == b"chunked") + { + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); + } + } + (Some(content_length), _) if content_length > N => { return Err(ContentLengthLimitRejection::PayloadTooLarge( PayloadTooLarge, )); } - } else { - return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); - }; + (None, _) => { + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); + } + _ => {} + } let value = T::from_request(req) .await @@ -72,7 +91,12 @@ impl Deref for ContentLengthLimit { #[cfg(test)] mod tests { use super::*; - use crate::{body::Bytes, routing::post, test_helpers::*, Router}; + use crate::{ + body::Bytes, + routing::{get, post}, + test_helpers::*, + Router, + }; use http::StatusCode; use serde::Deserialize; @@ -124,4 +148,45 @@ mod tests { .await; assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); } + + #[tokio::test] + async fn get_request_without_content_length_is_accepted() { + let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn get_request_with_content_length_is_rejected() { + let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); + + let client = TestClient::new(app); + + let res = client + .get("/") + .header("content-length", 3) + .body("foo") + .send() + .await; + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn get_request_with_chunked_encoding_is_rejected() { + let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); + + let client = TestClient::new(app); + + let res = client + .get("/") + .header("transfer-encoding", "chunked") + .body("3\r\nfoo\r\n0\r\n\r\n") + .send() + .await; + + assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); + } } diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index fc3e6a54..4d13faee 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -63,6 +63,14 @@ define_rejection! { pub struct LengthRequired; } +define_rejection! { + #[status = BAD_REQUEST] + #[body = "`GET`, `HEAD`, `OPTIONS` requests are not allowed to have a `Content-Length` header"] + /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if + /// the request is `GET`, `HEAD`, or `OPTIONS` and has a `Content-Length` header. + pub struct ContentLengthNotAllowed; +} + define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"] @@ -225,6 +233,8 @@ pub enum ContentLengthLimitRejection { #[allow(missing_docs)] LengthRequired(LengthRequired), #[allow(missing_docs)] + ContentLengthNotAllowed(ContentLengthNotAllowed), + #[allow(missing_docs)] Inner(T), } @@ -236,6 +246,7 @@ where match self { Self::PayloadTooLarge(inner) => inner.into_response(), Self::LengthRequired(inner) => inner.into_response(), + Self::ContentLengthNotAllowed(inner) => inner.into_response(), Self::Inner(inner) => inner.into_response(), } } @@ -249,6 +260,7 @@ where match self { Self::PayloadTooLarge(inner) => inner.fmt(f), Self::LengthRequired(inner) => inner.fmt(f), + Self::ContentLengthNotAllowed(inner) => inner.fmt(f), Self::Inner(inner) => inner.fmt(f), } } @@ -262,6 +274,7 @@ where match self { Self::PayloadTooLarge(inner) => Some(inner), Self::LengthRequired(inner) => Some(inner), + Self::ContentLengthNotAllowed(inner) => Some(inner), Self::Inner(inner) => Some(inner), } }