Remove ContentLengthLimit (#1400)

* feat: remove ContentLengthLimit

* feat: remove ContentLengthLimit rejections

* fix: update multipart docs

* fix: typo

* feat: add wip extractor code

* feat: revert "feat: add wip extractor code"

* fix: update Multipart docs

* fix: update examples

* fix: missing import in an example

* fix: broken import yet again

* fix: disable default body limit for example

* fix: key value store example

* fix: update expected debug_handler output

* chore: update CHANGELOG

* Update axum/CHANGELOG.md

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Marek Kuskowski 2022-09-24 13:29:53 +02:00 committed by GitHub
parent c3f3db79ec
commit 896ffc5fba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 33 additions and 383 deletions

View file

@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
<(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 26 others
and 25 others
= note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool`
note: required by a bound in `__axum_macros_check_handler_0_from_request_check`
--> tests/debug_handler/fail/argument_not_extractor.rs:3:1

View file

@ -13,6 +13,6 @@ error[E0277]: the trait bound `String: FromRequestParts<()>` is not satisfied
<(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 26 others
and 25 others
= help: see issue #48214
= note: this error originates in the attribute macro `debug_handler` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied
(Response<()>, T1, T2, R)
(Response<()>, T1, T2, T3, R)
(Response<()>, T1, T2, T3, T4, R)
and 122 others
and 118 others
note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check`
--> tests/debug_handler/fail/wrong_return_type.rs:4:23
|

View file

@ -13,4 +13,4 @@ error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
<(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 27 others
and 26 others

View file

@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `middleware::from_extractor_with_state` and
`middleware::from_extractor_with_state_arc` ([#1396])
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])
- **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400])
[#1371]: https://github.com/tokio-rs/axum/pull/1371
[#1387]: https://github.com/tokio-rs/axum/pull/1387
[#1389]: https://github.com/tokio-rs/axum/pull/1389
[#1396]: https://github.com/tokio-rs/axum/pull/1396
[#1397]: https://github.com/tokio-rs/axum/pull/1397
[#1400]: https://github.com/tokio-rs/axum/pull/1400
# 0.6.0-rc.2 (10. September, 2022)

View file

@ -1,274 +0,0 @@
use super::{rejection::*, FromRequest};
use async_trait::async_trait;
use axum_core::{extract::FromRequestParts, response::IntoResponse};
use http::{request::Parts, Method, Request};
use http_body::Limited;
use std::ops::Deref;
/// Extractor that will reject requests with a body larger than some size.
///
/// `GET`, `HEAD`, and `OPTIONS` requests are rejected if they have a `Content-Length` header,
/// otherwise they're accepted without the body being checked.
///
/// Note: `ContentLengthLimit` can wrap types that extract the body (for example, [`Form`] or [`Json`])
/// if that is the case, the inner type will consume the request's body, which means the
/// `ContentLengthLimit` must come *last* if the handler uses several extractors. See
/// ["the order of extractors"][order-of-extractors]
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
/// [`Form`]: crate::form::Form
/// [`Json`]: crate::json::Json
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
/// extract::ContentLengthLimit,
/// routing::post,
/// Router,
/// };
///
/// async fn handler(body: ContentLengthLimit<String, 1024>) {
/// // ...
/// }
///
/// let app = Router::new().route("/", post(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
#[derive(Debug, Clone)]
pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait]
impl<T, S, B, R, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where
T: FromRequest<S, B, Rejection = R> + FromRequest<S, Limited<B>, Rejection = R>,
R: IntoResponse + Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<R>;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let value = if let Some(err) = validate::<N>(&parts).err() {
match err {
RequestValidationError::LengthRequiredStream => {
// `Limited` supports limiting streams, so use that instead since this is a
// streaming request
let body = Limited::new(body, N as usize);
let req = Request::from_parts(parts, body);
T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?
}
other => return Err(other.into()),
}
} else {
let req = Request::from_parts(parts, body);
T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?
};
Ok(Self(value))
}
}
#[async_trait]
impl<T, S, const N: u64> FromRequestParts<S> for ContentLengthLimit<T, N>
where
T: FromRequestParts<S>,
T::Rejection: IntoResponse,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
validate::<N>(parts)?;
let value = T::from_request_parts(parts, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
}
fn validate<const N: u64>(parts: &Parts) -> Result<(), RequestValidationError> {
let content_length = parts
.headers
.get(http::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
match (content_length, &parts.method) {
(content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
if content_length.is_some() {
return Err(RequestValidationError::ContentLengthNotAllowed);
} else if parts
.headers
.get(http::header::TRANSFER_ENCODING)
.map_or(false, |value| value.as_bytes() == b"chunked")
{
return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet);
}
}
(Some(content_length), _) if content_length > N => {
return Err(RequestValidationError::PayloadTooLarge);
}
(None, _) => {
return Err(RequestValidationError::LengthRequiredStream);
}
_ => {}
}
Ok(())
}
impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the
/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet`
enum RequestValidationError {
PayloadTooLarge,
LengthRequiredStream,
LengthRequiredChunkedHeadOrGet,
ContentLengthNotAllowed,
}
impl<T> From<RequestValidationError> for ContentLengthLimitRejection<T> {
fn from(inner: RequestValidationError) -> Self {
match inner {
RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge),
RequestValidationError::LengthRequiredStream
| RequestValidationError::LengthRequiredChunkedHeadOrGet => {
Self::LengthRequired(LengthRequired)
}
RequestValidationError::ContentLengthNotAllowed => {
Self::ContentLengthNotAllowed(ContentLengthNotAllowed)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
body::Bytes,
routing::{get, post},
test_helpers::*,
Router,
};
use http::StatusCode;
use serde::Deserialize;
#[tokio::test]
async fn body_with_length_limit() {
use std::iter::repeat;
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Input {
foo: String,
}
const LIMIT: u64 = 8;
let app = Router::new().route(
"/",
post(|_body: ContentLengthLimit<Bytes, LIMIT>| async {}),
);
let client = TestClient::new(app);
let res = client
.post("/")
.body(repeat(0_u8).take((LIMIT - 1) as usize).collect::<Vec<_>>())
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post("/")
.body(repeat(0_u8).take(LIMIT as usize).collect::<Vec<_>>())
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post("/")
.body(repeat(0_u8).take((LIMIT + 1) as usize).collect::<Vec<_>>())
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
let chunk = repeat(0_u8).take(LIMIT as usize).collect::<Bytes>();
let res = client
.post("/")
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
vec![Ok::<_, std::io::Error>(chunk)],
)))
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::<Bytes>();
let res = client
.post("/")
.body(reqwest::Body::wrap_stream(futures_util::stream::iter(
vec![Ok::<_, std::io::Error>(chunk)],
)))
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
#[tokio::test]
async fn get_request_without_content_length_is_accepted() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn get_request_with_content_length_is_rejected() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client
.get("/")
.header("content-length", 3)
.body("foo")
.send()
.await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn get_request_with_chunked_encoding_is_rejected() {
let app = Router::new().route("/", get(|_body: ContentLengthLimit<Bytes, 1337>| async {}));
let client = TestClient::new(app);
let res = client
.get("/")
.header("transfer-encoding", "chunked")
.body("3\r\nfoo\r\n0\r\n\r\n")
.send()
.await;
assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED);
}
}

View file

@ -9,7 +9,6 @@ pub mod rejection;
#[cfg(feature = "ws")]
pub mod ws;
mod content_length_limit;
mod host;
mod raw_query;
mod request_parts;
@ -25,7 +24,6 @@ pub use axum_macros::{FromRequest, FromRequestParts};
#[allow(deprecated)]
pub use self::{
connect_info::ConnectInfo,
content_length_limit::ContentLengthLimit,
host::Host,
path::Path,
raw_query::RawQuery,

View file

@ -49,7 +49,8 @@ use std::{
/// ```
///
/// For security reasons it's recommended to combine this with
/// [`ContentLengthLimit`](super::ContentLengthLimit) to limit the size of the request payload.
/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer)
/// to limit the size of the request payload.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {

View file

@ -47,30 +47,6 @@ define_rejection! {
pub struct MissingExtension(Error);
}
define_rejection! {
#[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"]
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request body is too large.
pub struct PayloadTooLarge;
}
define_rejection! {
#[status = LENGTH_REQUIRED]
#[body = "Content length header is required"]
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request is missing the `Content-Length` header or it is invalid.
pub struct LengthRequired;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`GET`, `HEAD`, `OPTIONS` requests are not allowed to have a `Content-Length` header"]
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request is `GET`, `HEAD`, or `OPTIONS` and has a `Content-Length` header.
pub struct ContentLengthNotAllowed;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "No paths parameters found for matched route"]
@ -216,64 +192,5 @@ composite_rejection! {
}
}
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the
/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
PayloadTooLarge(PayloadTooLarge),
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
ContentLengthNotAllowed(ContentLengthNotAllowed),
#[allow(missing_docs)]
Inner(T),
}
impl<T> IntoResponse for ContentLengthLimitRejection<T>
where
T: IntoResponse,
{
fn into_response(self) -> Response {
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::ContentLengthNotAllowed(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
}
impl<T> std::fmt::Display for ContentLengthLimitRejection<T>
where
T: std::fmt::Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::PayloadTooLarge(inner) => inner.fmt(f),
Self::LengthRequired(inner) => inner.fmt(f),
Self::ContentLengthNotAllowed(inner) => inner.fmt(f),
Self::Inner(inner) => inner.fmt(f),
}
}
}
impl<T> std::error::Error for ContentLengthLimitRejection<T>
where
T: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::PayloadTooLarge(inner) => Some(inner),
Self::LengthRequired(inner) => Some(inner),
Self::ContentLengthNotAllowed(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}
}
#[cfg(feature = "headers")]
pub use crate::typed_header::{TypedHeaderRejection, TypedHeaderRejectionReason};

View file

@ -8,6 +8,12 @@ publish = false
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] }
tower-http = { version = "0.3.0", features = ["add-extension", "auth", "compression-full", "trace"] }
tower-http = { version = "0.3.0", features = [
"add-extension",
"auth",
"compression-full",
"limit",
"trace",
] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -9,7 +9,7 @@
use axum::{
body::Bytes,
error_handling::HandleErrorLayer,
extract::{ContentLengthLimit, Path, State},
extract::{DefaultBodyLimit, Path, State},
handler::Handler,
http::StatusCode,
response::IntoResponse,
@ -25,7 +25,8 @@ use std::{
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer,
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
trace::TraceLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -48,7 +49,12 @@ async fn main() {
// Add compression to `kv_get`
get(kv_get.layer(CompressionLayer::new()))
// But don't compress `kv_set`
.post(kv_set),
.post_service(
ServiceBuilder::new()
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */))
.service(kv_set.with_state(Arc::clone(&shared_state))),
),
)
.route("/keys", get(list_keys))
// Nest our admin routes under `/admin`
@ -94,11 +100,7 @@ async fn kv_get(
}
}
async fn kv_set(
Path(key): Path<String>,
State(state): State<SharedState>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
) {
async fn kv_set(Path(key): Path<String>, State(state): State<SharedState>, bytes: Bytes) {
state.write().unwrap().db.insert(key, bytes);
}

View file

@ -7,6 +7,6 @@ publish = false
[dependencies]
axum = { path = "../../axum", features = ["multipart"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.3.0", features = ["trace"] }
tower-http = { version = "0.3.0", features = ["limit", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -5,12 +5,13 @@
//! ```
use axum::{
extract::{ContentLengthLimit, Multipart},
extract::{DefaultBodyLimit, Multipart},
response::Html,
routing::get,
Router,
};
use std::net::SocketAddr;
use tower_http::limit::RequestBodyLimitLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
@ -26,6 +27,10 @@ async fn main() {
// build our application with some routes
let app = Router::new()
.route("/", get(show_form).post(accept_form))
.layer(DefaultBodyLimit::disable())
.layer(RequestBodyLimitLayer::new(
250 * 1024 * 1024, /* 250mb */
))
.layer(tower_http::trace::TraceLayer::new_for_http());
// run it with hyper
@ -58,14 +63,7 @@ async fn show_form() -> Html<&'static str> {
)
}
async fn accept_form(
ContentLengthLimit(mut multipart): ContentLengthLimit<
Multipart,
{
250 * 1024 * 1024 /* 250mb */
},
>,
) {
async fn accept_form(mut multipart: Multipart) {
while let Some(field) = multipart.next_field().await.unwrap() {
let name = field.name().unwrap().to_string();
let file_name = field.file_name().unwrap().to_string();