mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
Implement IntoResponse
for MultipartError
(#1861)
This commit is contained in:
parent
8e1eb8979f
commit
03e8bc77f1
5 changed files with 165 additions and 5 deletions
|
@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning].
|
|||
|
||||
# Unreleased
|
||||
|
||||
- None.
|
||||
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])
|
||||
|
||||
# 0.7.1 (13. March, 2023)
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ axum = { path = "../axum", version = "0.6.9", default-features = false }
|
|||
bytes = "1.1.0"
|
||||
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
|
||||
http = "0.2"
|
||||
http-body = "0.4.4"
|
||||
mime = "0.3"
|
||||
pin-project-lite = "0.2"
|
||||
tokio = "1.19"
|
||||
|
|
|
@ -12,9 +12,10 @@ use axum::{
|
|||
use futures_util::stream::Stream;
|
||||
use http::{
|
||||
header::{HeaderMap, CONTENT_TYPE},
|
||||
Request,
|
||||
Request, StatusCode,
|
||||
};
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
|
@ -246,6 +247,57 @@ impl MultipartError {
|
|||
fn from_multer(multer: multer::Error) -> Self {
|
||||
Self { source: multer }
|
||||
}
|
||||
|
||||
/// Get the response body text used for this rejection.
|
||||
pub fn body_text(&self) -> String {
|
||||
self.source.to_string()
|
||||
}
|
||||
|
||||
/// Get the status code used for this rejection.
|
||||
pub fn status(&self) -> http::StatusCode {
|
||||
status_code_from_multer_error(&self.source)
|
||||
}
|
||||
}
|
||||
|
||||
fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
|
||||
match err {
|
||||
multer::Error::UnknownField { .. }
|
||||
| multer::Error::IncompleteFieldData { .. }
|
||||
| multer::Error::IncompleteHeaders
|
||||
| multer::Error::ReadHeaderFailed(..)
|
||||
| multer::Error::DecodeHeaderName { .. }
|
||||
| multer::Error::DecodeContentType(..)
|
||||
| multer::Error::NoBoundary
|
||||
| multer::Error::DecodeHeaderValue { .. }
|
||||
| multer::Error::NoMultipart
|
||||
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
|
||||
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
|
||||
StatusCode::PAYLOAD_TOO_LARGE
|
||||
}
|
||||
multer::Error::StreamReadFailed(err) => {
|
||||
if let Some(err) = err.downcast_ref::<multer::Error>() {
|
||||
return status_code_from_multer_error(err);
|
||||
}
|
||||
|
||||
if err
|
||||
.downcast_ref::<axum::Error>()
|
||||
.and_then(|err| err.source())
|
||||
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
|
||||
.is_some()
|
||||
{
|
||||
return StatusCode::PAYLOAD_TOO_LARGE;
|
||||
}
|
||||
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for MultipartError {
|
||||
fn into_response(self) -> Response {
|
||||
(self.status(), self.body_text()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for MultipartError {
|
||||
|
@ -357,7 +409,9 @@ impl std::error::Error for InvalidBoundary {}
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::test_helpers::*;
|
||||
use axum::{body::Body, response::IntoResponse, routing::post, Router};
|
||||
use axum::{
|
||||
body::Body, extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn content_type_with_encoding() {
|
||||
|
@ -395,4 +449,28 @@ mod tests {
|
|||
async fn handler(_: Multipart) {}
|
||||
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn body_too_large() {
|
||||
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
|
||||
|
||||
async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
|
||||
while let Some(field) = multipart.next_field().await? {
|
||||
field.bytes().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(handle))
|
||||
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let form =
|
||||
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
|
||||
|
||||
let res = client.post("/").multipart(form).send().await;
|
||||
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
# Unreleased
|
||||
|
||||
- None.
|
||||
- **added:** Implement `IntoResponse` for `MultipartError` ([#1861])
|
||||
|
||||
[#1861]: https://github.com/tokio-rs/axum/pull/1861
|
||||
|
||||
# 0.6.11 (13. March, 2023)
|
||||
|
||||
|
|
|
@ -6,10 +6,12 @@ use super::{BodyStream, FromRequest};
|
|||
use crate::body::{Bytes, HttpBody};
|
||||
use crate::BoxError;
|
||||
use async_trait::async_trait;
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
use axum_core::RequestExt;
|
||||
use futures_util::stream::Stream;
|
||||
use http::header::{HeaderMap, CONTENT_TYPE};
|
||||
use http::Request;
|
||||
use http::{Request, StatusCode};
|
||||
use std::error::Error;
|
||||
use std::{
|
||||
fmt,
|
||||
pin::Pin,
|
||||
|
@ -209,6 +211,51 @@ impl MultipartError {
|
|||
fn from_multer(multer: multer::Error) -> Self {
|
||||
Self { source: multer }
|
||||
}
|
||||
|
||||
/// Get the response body text used for this rejection.
|
||||
pub fn body_text(&self) -> String {
|
||||
self.source.to_string()
|
||||
}
|
||||
|
||||
/// Get the status code used for this rejection.
|
||||
pub fn status(&self) -> http::StatusCode {
|
||||
status_code_from_multer_error(&self.source)
|
||||
}
|
||||
}
|
||||
|
||||
fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
|
||||
match err {
|
||||
multer::Error::UnknownField { .. }
|
||||
| multer::Error::IncompleteFieldData { .. }
|
||||
| multer::Error::IncompleteHeaders
|
||||
| multer::Error::ReadHeaderFailed(..)
|
||||
| multer::Error::DecodeHeaderName { .. }
|
||||
| multer::Error::DecodeContentType(..)
|
||||
| multer::Error::NoBoundary
|
||||
| multer::Error::DecodeHeaderValue { .. }
|
||||
| multer::Error::NoMultipart
|
||||
| multer::Error::IncompleteStream => StatusCode::BAD_REQUEST,
|
||||
multer::Error::FieldSizeExceeded { .. } | multer::Error::StreamSizeExceeded { .. } => {
|
||||
StatusCode::PAYLOAD_TOO_LARGE
|
||||
}
|
||||
multer::Error::StreamReadFailed(err) => {
|
||||
if let Some(err) = err.downcast_ref::<multer::Error>() {
|
||||
return status_code_from_multer_error(err);
|
||||
}
|
||||
|
||||
if err
|
||||
.downcast_ref::<crate::Error>()
|
||||
.and_then(|err| err.source())
|
||||
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
|
||||
.is_some()
|
||||
{
|
||||
return StatusCode::PAYLOAD_TOO_LARGE;
|
||||
}
|
||||
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for MultipartError {
|
||||
|
@ -223,6 +270,12 @@ impl std::error::Error for MultipartError {
|
|||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for MultipartError {
|
||||
fn into_response(self) -> Response {
|
||||
(self.status(), self.body_text()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_boundary(headers: &HeaderMap) -> Option<String> {
|
||||
let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?;
|
||||
multer::parse_boundary(content_type).ok()
|
||||
|
@ -247,6 +300,8 @@ define_rejection! {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum_core::extract::DefaultBodyLimit;
|
||||
|
||||
use super::*;
|
||||
use crate::{body::Body, response::IntoResponse, routing::post, test_helpers::*, Router};
|
||||
|
||||
|
@ -286,4 +341,28 @@ mod tests {
|
|||
async fn handler(_: Multipart) {}
|
||||
let _app: Router<(), http_body::Limited<Body>> = Router::new().route("/", post(handler));
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn body_too_large() {
|
||||
const BYTES: &[u8] = "<!doctype html><title>🦀</title>".as_bytes();
|
||||
|
||||
async fn handle(mut multipart: Multipart) -> Result<(), MultipartError> {
|
||||
while let Some(field) = multipart.next_field().await? {
|
||||
field.bytes().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(handle))
|
||||
.layer(DefaultBodyLimit::max(BYTES.len() - 1));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let form =
|
||||
reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES));
|
||||
|
||||
let res = client.post("/").multipart(form).send().await;
|
||||
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue