1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Parameterize ContentLengthLimit

This commit is contained in:
David Pedersen 2021-06-09 08:14:20 +02:00
parent 09f76f3c87
commit 90c3e5ba74
6 changed files with 21 additions and 18 deletions

View file

@ -23,7 +23,7 @@ use tower_http::{
};
use tower_web::{
body::{Body, BoxBody},
extract::{BytesMaxLength, Extension, UrlParams},
extract::{ContentLengthLimit, Extension, UrlParams},
prelude::*,
response::IntoResponse,
routing::BoxRoute,
@ -88,10 +88,10 @@ async fn kv_get(
async fn kv_set(
_req: Request<Body>,
UrlParams((key,)): UrlParams<(String,)>,
BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
Extension(state): Extension<SharedState>,
) {
state.write().unwrap().db.insert(key, value);
state.write().unwrap().db.insert(key, bytes);
}
async fn list_keys(_req: Request<Body>, Extension(state): Extension<SharedState>) -> String {

View file

@ -388,14 +388,14 @@ impl FromRequest for Body {
}
}
/// Extractor that will buffer request bodies up to a certain size.
/// Extractor that will reject requests with a body larger than some size.
///
/// # Example
///
/// ```rust,no_run
/// use tower_web::prelude::*;
///
/// async fn handler(req: Request<Body>, body: extract::BytesMaxLength<1024>) {
/// async fn handler(req: Request<Body>, body: extract::ContentLengthLimit<String, 1024>) {
/// // ...
/// }
///
@ -404,15 +404,17 @@ impl FromRequest for Body {
///
/// This requires the request to have a `Content-Length` header.
#[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(pub Bytes);
pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait]
impl<const N: u64> FromRequest for BytesMaxLength<N> {
impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
where
T: FromRequest,
{
type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
let body = take_body(req).map_err(|reject| reject.into_response())?;
let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
@ -425,11 +427,11 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
return Err(LengthRequired.into_response());
};
let bytes = hyper::body::to_bytes(body)
let value = T::from_request(req)
.await
.map_err(|e| FailedToBufferBody::from_err(e).into_response())?;
.map_err(IntoResponse::into_response)?;
Ok(BytesMaxLength(bytes))
Ok(Self(value))
}
}

View file

@ -103,16 +103,16 @@ define_rejection! {
define_rejection! {
#[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"]
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
/// request body is too large.
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request body is too large.
pub struct PayloadTooLarge;
}
define_rejection! {
#[status = LENGTH_REQUIRED]
#[body = "Content length header is required"]
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
/// request is missing the `Content-Length` header or it is invalid.
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// the request is missing the `Content-Length` header or it is invalid.
pub struct LengthRequired;
}

View file

@ -1,8 +1,8 @@
//! Handler future types.
use crate::body::BoxBody;
use http::Response;
use std::convert::Infallible;
use crate::body::BoxBody;
opaque_future! {
/// The response future for [`IntoService`](super::IntoService).

View file

@ -50,8 +50,8 @@ use crate::{
service::HandleError,
};
use async_trait::async_trait;
use futures_util::future::Either;
use bytes::Bytes;
use futures_util::future::Either;
use http::{Request, Response};
use std::{
convert::Infallible,

View file

@ -1,4 +1,5 @@
use crate::{handler::on, prelude::*, routing::MethodFilter, service};
use bytes::Bytes;
use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
@ -137,7 +138,7 @@ async fn body_with_length_limit() {
let app = route(
"/",
post(
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
|req: Request<Body>, _body: extract::ContentLengthLimit<Bytes, LIMIT>| async move {
dbg!(&req);
},
),