Handle GET, HEAD, and OPTIONS correctly in ContentLengthLimit (#989)

* Handle `GET`/`HEAD`/`OPTIONS` in `ContentLengthLimit`

* changelog

* Apply suggestions from code review

Co-authored-by: Marcus Griep <marcus@griep.us>

* Don't allow GET/HEAD/OPTIONS with `transfer-encoding: chunked`

* simplify constructing chunked body

* Update axum/src/extract/content_length_limit.rs

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

Co-authored-by: Marcus Griep <marcus@griep.us>
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-05-05 09:17:54 +02:00 committed by GitHub
parent 4b384fa01c
commit d19beffd6d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 11 deletions

View file

@ -7,7 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- None. - **fixed:** Correctly handle `GET`, `HEAD`, and `OPTIONS` requests in `ContentLengthLimit`.
Request with these methods are now accepted if they _do not_ have a `Content-Length` header, and
the request body will not be checked. If they do have a `Content-Length` header they'll be
rejected. This allows `ContentLengthLimit` to be used as middleware around several routes,
including `GET` routes ([#989])
[#989]: https://github.com/tokio-rs/axum/pull/989
# 0.5.4 (26. April, 2022) # 0.5.4 (26. April, 2022)

View file

@ -1,10 +1,14 @@
use super::{rejection::*, FromRequest, RequestParts}; use super::{rejection::*, FromRequest, RequestParts};
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use http::Method;
use std::ops::Deref; use std::ops::Deref;
/// Extractor that will reject requests with a body larger than some size. /// Extractor that will reject requests with a body larger than some size.
/// ///
/// `GET`, `HEAD`, and `OPTIONS` requests are rejected if they have a `Content-Length` header,
/// otherwise they're accepted without the body being checked.
///
/// # Example /// # Example
/// ///
/// ```rust,no_run /// ```rust,no_run
@ -38,20 +42,35 @@ where
type Rejection = ContentLengthLimitRejection<T::Rejection>; type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH); let content_length = req
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
let content_length = match (content_length, req.method()) {
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
if content_length.is_some() {
if let Some(length) = content_length { return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
if length > N { ContentLengthNotAllowed,
));
} else if req
.headers()
.get(http::header::TRANSFER_ENCODING)
.map_or(false, |value| value.as_bytes() == b"chunked")
{
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
}
}
(Some(content_length), _) if content_length > N => {
return Err(ContentLengthLimitRejection::PayloadTooLarge( return Err(ContentLengthLimitRejection::PayloadTooLarge(
PayloadTooLarge, PayloadTooLarge,
)); ));
} }
} else { (None, _) => {
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
}; }
_ => {}
}
let value = T::from_request(req) let value = T::from_request(req)
.await .await
@ -72,7 +91,12 @@ impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{body::Bytes, routing::post, test_helpers::*, Router}; use crate::{
body::Bytes,
routing::{get, post},
test_helpers::*,
Router,
};
use http::StatusCode; use http::StatusCode;
use serde::Deserialize; use serde::Deserialize;
@ -124,4 +148,45 @@ mod tests {
.await; .await;
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
} }
#[tokio::test]
async fn get_request_without_content_length_is_accepted() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn get_request_with_content_length_is_rejected() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client
.get("/")
.header("content-length", 3)
.body("foo")
.send()
.await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn get_request_with_chunked_encoding_is_rejected() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client
.get("/")
.header("transfer-encoding", "chunked")
.body("3\r\nfoo\r\n0\r\n\r\n")
.send()
.await;
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
}
} }

View file

@ -63,6 +63,14 @@ define_rejection! {
pub struct LengthRequired; pub struct LengthRequired;
} }
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`GET`, `HEAD`, `OPTIONS` requests are not allowed to have a `Content-Length` header"]
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request is `GET`, `HEAD`, or `OPTIONS` and has a `Content-Length` header.
pub struct ContentLengthNotAllowed;
}
define_rejection! { define_rejection! {
#[status = INTERNAL_SERVER_ERROR] #[status = INTERNAL_SERVER_ERROR]
#[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"] #[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"]
@ -225,6 +233,8 @@ pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)] #[allow(missing_docs)]
LengthRequired(LengthRequired), LengthRequired(LengthRequired),
#[allow(missing_docs)] #[allow(missing_docs)]
ContentLengthNotAllowed(ContentLengthNotAllowed),
#[allow(missing_docs)]
Inner(T), Inner(T),
} }
@ -236,6 +246,7 @@ where
match self { match self {
Self::PayloadTooLarge(inner) => inner.into_response(), Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(), Self::LengthRequired(inner) => inner.into_response(),
Self::ContentLengthNotAllowed(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(), Self::Inner(inner) => inner.into_response(),
} }
} }
@ -249,6 +260,7 @@ where
match self { match self {
Self::PayloadTooLarge(inner) => inner.fmt(f), Self::PayloadTooLarge(inner) => inner.fmt(f),
Self::LengthRequired(inner) => inner.fmt(f), Self::LengthRequired(inner) => inner.fmt(f),
Self::ContentLengthNotAllowed(inner) => inner.fmt(f),
Self::Inner(inner) => inner.fmt(f), Self::Inner(inner) => inner.fmt(f),
} }
} }
@ -262,6 +274,7 @@ where
match self { match self {
Self::PayloadTooLarge(inner) => Some(inner), Self::PayloadTooLarge(inner) => Some(inner),
Self::LengthRequired(inner) => Some(inner), Self::LengthRequired(inner) => Some(inner),
Self::ContentLengthNotAllowed(inner) => Some(inner),
Self::Inner(inner) => Some(inner), Self::Inner(inner) => Some(inner),
} }
} }