Add RequestExt::{with_limited_body, into_limited_body} (#1420)

* Move RequestExt and RequestPartsExt into axum-core

* Add RequestExt::into_limited_body

… and use it for Bytes extraction.

* Add RequestExt::with_limited_body

… and use it for Multipart extraction.

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2022-09-28 22:20:47 +02:00 committed by GitHub
parent be54583d98
commit b94248191e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 87 additions and 44 deletions

View file

@ -1,15 +1,34 @@
pub(crate) mod request;
pub(crate) mod request_parts;
pub(crate) mod service;
#[cfg(test)]
mod tests {
use std::convert::Infallible;
use crate::extract::{FromRef, FromRequestParts};
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct State<S>(pub(crate) S);
#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
_parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}
// some extractor that requires the state, such as `SignedCookieJar`
pub(crate) struct RequiresState(pub(crate) String);

View file

@ -1,6 +1,7 @@
use axum_core::extract::{FromRequest, FromRequestParts};
use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use http_body::Limited;
mod sealed {
pub trait Sealed<B> {}
@ -48,6 +49,16 @@ pub trait RequestExt<B>: sealed::Sealed<B> + Sized {
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;
/// Apply the [default body limit](crate::extract::DefaultBodyLimit).
///
/// If it is disabled, return the request as-is in `Err`.
fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>;
/// Consumes the request, returning the body wrapped in [`Limited`] if a
/// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
/// default limit is disabled.
fn into_limited_body(self) -> Result<Limited<B>, B>;
}
impl<B> RequestExt<B> for Request<B>
@ -105,14 +116,36 @@ where
result
})
}
fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
match self.extensions().get::<DefaultBodyLimitKind>().copied() {
Some(DefaultBodyLimitKind::Disable) => Err(self),
Some(DefaultBodyLimitKind::Limit(limit)) => {
Ok(self.map(|b| http_body::Limited::new(b, limit)))
}
None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))),
}
}
fn into_limited_body(self) -> Result<Limited<B>, B> {
self.with_limited_body()
.map(Request::into_body)
.map_err(Request::into_body)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::Method;
use hyper::Body;

View file

@ -1,4 +1,4 @@
use axum_core::extract::FromRequestParts;
use crate::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::request::Parts;
@ -53,9 +53,11 @@ mod tests {
use std::convert::Infallible;
use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::{Method, Request};
#[tokio::test]
@ -73,7 +75,10 @@ mod tests {
let state = "state".to_owned();
let State(extracted_state): State<String> = parts.extract_with_state(&state).await.unwrap();
let State(extracted_state): State<String> = parts
.extract_with_state::<State<String>, String>(&state)
.await
.unwrap();
assert_eq!(extracted_state, state);
}

View file

@ -16,6 +16,7 @@ mod from_ref;
mod request_parts;
mod tuple;
pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
mod private {

View file

@ -1,7 +1,5 @@
use super::{
default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts,
};
use crate::BoxError;
use super::{rejection::*, FromRequest, FromRequestParts};
use crate::{BoxError, RequestExt};
use async_trait::async_trait;
use bytes::Bytes;
use http::{request::Parts, HeaderMap, Method, Request, Uri, Version};
@ -84,27 +82,13 @@ where
type Rejection = BytesRejection;
async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied();
let bytes = match limit_kind {
Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body())
let bytes = match req.into_limited_body() {
Ok(limited_body) => crate::body::to_bytes(limited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Err(unlimited_body) => crate::body::to_bytes(unlimited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Some(DefaultBodyLimitKind::Limit(limit)) => {
let body = http_body::Limited::new(req.into_body(), limit);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
None => {
let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
};
Ok(bytes)

View file

@ -52,6 +52,7 @@
pub(crate) mod macros;
mod error;
mod ext_traits;
pub use self::error::Error;
pub mod body;
@ -60,3 +61,5 @@ pub mod response;
/// Alias for a type-erased error type.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};

View file

@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
you likely need to re-enable the `tokio` feature ([#1382])
- **breaking:** `handler::{WithState, IntoService}` are merged into one type,
named `HandlerService` ([#1418])
- **changed:** The default body limit now applies to the `Multipart` extractor ([#1420])
- **added:** String and binary `From` impls have been added to `extract::ws::Message`
to be more inline with `tungstenite` ([#1421])
@ -54,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1408]: https://github.com/tokio-rs/axum/pull/1408
[#1414]: https://github.com/tokio-rs/axum/pull/1414
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421
# 0.6.0-rc.2 (10. September, 2022)

View file

@ -6,6 +6,7 @@ use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
@ -47,10 +48,6 @@ use std::{
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// For security reasons it's recommended to combine this with
/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer)
/// to limit the size of the request payload.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {
@ -69,10 +66,11 @@ where
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let stream = match BodyStream::from_request(req, state).await {
Ok(stream) => stream,
Err(err) => match err {},
let stream_result = match req.with_limited_body() {
Ok(limited) => BodyStream::from_request(limited, state).await,
Err(unlimited) => BodyStream::from_request(unlimited, state).await,
};
let stream = stream_result.unwrap_or_else(|err| match err {});
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}

View file

@ -434,12 +434,12 @@
#[macro_use]
pub(crate) mod macros;
mod ext_traits;
mod extension;
#[cfg(feature = "form")]
mod form;
#[cfg(feature = "json")]
mod json;
mod service_ext;
#[cfg(feature = "headers")]
mod typed_header;
mod util;
@ -483,11 +483,9 @@ pub use self::typed_header::TypedHeader;
pub use self::form::Form;
#[doc(inline)]
pub use axum_core::{BoxError, Error};
pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt};
#[cfg(feature = "macros")]
pub use axum_macros::debug_handler;
pub use self::ext_traits::{
request::RequestExt, request_parts::RequestPartsExt, service::ServiceExt,
};
pub use self::service_ext::ServiceExt;