Add Multipart extractor for consuming multipart/form-data requests (#32)

Multipart implementation on top of `multer`
This commit is contained in:
David Pedersen 2021-07-14 16:53:37 +02:00 committed by GitHub
parent e641caefaf
commit 028c472c84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 358 additions and 103 deletions

View file

@ -15,6 +15,7 @@ version = "0.1.0"
[features]
default = []
ws = ["tokio-tungstenite", "sha-1", "base64"]
multipart = ["multer", "mime"]
[dependencies]
async-trait = "0.1"
@ -38,6 +39,8 @@ tokio-tungstenite = { optional = true, version = "0.14" }
sha-1 = { optional = true, version = "0.9.6" }
base64 = { optional = true, version = "0.13" }
headers = { optional = true, version = "0.3" }
multer = { optional = true, version = "2.0.0" }
mime = { optional = true, version = "0.3" }
[dev-dependencies]
askama = "0.10.5"

View file

@ -0,0 +1,59 @@
use axum::{
extract::{ContentLengthLimit, Multipart},
prelude::*,
};
use std::net::SocketAddr;
/// Starts the example server on `127.0.0.1:3000`.
///
/// `GET /` serves the upload form and `POST /` receives the multipart submission. Panics if
/// the server terminates with an error.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Both verbs live on `/`; every request is wrapped in HTTP tracing.
    let trace_layer = tower_http::trace::TraceLayer::new_for_http();
    let app = route("/", get(show_form).post(accept_form)).layer(trace_layer);

    // Serve the application with hyper.
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    let server = hyper::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();
}
/// Serves the static HTML page containing the upload form.
///
/// The form posts back to `/` encoded as `multipart/form-data` and exposes a single file input
/// named `file` that accepts several files at once (`multiple`).
async fn show_form() -> response::Html<&'static str> {
response::Html(
r#"
<!doctype html>
<html>
<head></head>
<body>
<form action="/" method="post" enctype="multipart/form-data">
<label>
Upload file:
<input type="file" name="file" multiple>
</label>
<input type="submit" value="Upload files">
</form>
</body>
</html>
"#,
)
}
/// Handles the form submission by printing the byte length of every uploaded field.
///
/// The extractor is wrapped in `ContentLengthLimit` with a limit of `250 * 1024 * 1024` bytes.
/// This example handler panics (via `unwrap`) if reading a field fails or if a field carries
/// no name.
async fn accept_form(
    ContentLengthLimit(mut multipart): ContentLengthLimit<
        Multipart,
        {
            250 * 1024 * 1024 /* 250mb */
        },
    >,
) {
    loop {
        // `None` means every field of the request has been consumed.
        let field = match multipart.next_field().await.unwrap() {
            Some(field) => field,
            None => break,
        };

        // Copy the name out first: `bytes` consumes the field.
        let name = field.name().unwrap().to_owned();
        let data = field.bytes().await.unwrap();

        println!("Length of `{}` is {} bytes", name, data.len());
    }
}

View file

@ -171,16 +171,11 @@
//! # };
//! ```
use crate::{
body::{BoxBody, BoxStdError},
response::IntoResponse,
util::ByteStr,
};
use crate::{response::IntoResponse, util::ByteStr};
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use futures_util::stream::Stream;
use http::{header, HeaderMap, Method, Request, Uri, Version};
use http_body::Body;
use rejection::*;
use serde::de::DeserializeOwned;
use std::{
@ -198,6 +193,15 @@ pub mod rejection;
#[doc(inline)]
pub use self::extractor_middleware::extractor_middleware;
#[cfg(feature = "multipart")]
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub mod multipart;
#[cfg(feature = "multipart")]
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[doc(inline)]
pub use self::multipart::Multipart;
/// Types that can be created from requests.
///
/// See the [module docs](crate::extract) for more details.
@ -552,10 +556,13 @@ where
/// # };
/// ```
#[derive(Debug)]
pub struct BodyStream(BoxBody);
pub struct BodyStream<B = crate::body::Body>(B);
impl Stream for BodyStream {
type Item = Result<Bytes, BoxStdError>;
impl<B> Stream for BodyStream<B>
where
B: http_body::Body + Unpin,
{
type Item = Result<B::Data, B::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.0).poll_data(cx)
@ -563,17 +570,15 @@ impl Stream for BodyStream {
}
#[async_trait]
impl<B> FromRequest<B> for BodyStream
impl<B> FromRequest<B> for BodyStream<B>
where
B: http_body::Body<Data = Bytes> + Default + Send + Sync + 'static,
B::Data: Send,
B::Error: Into<tower::BoxError>,
B: http_body::Body + Default + Unpin + Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let stream = BodyStream(BoxBody::new(body));
let stream = BodyStream(body);
Ok(stream)
}
}

187
src/extract/multipart.rs Normal file
View file

@ -0,0 +1,187 @@
//! Extractor that parses `multipart/form-data` requests commonly used with file uploads.
//!
//! See [`Multipart`] for more details.
use super::{rejection::*, BodyStream, FromRequest};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use mime::Mime;
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use tower::BoxError;
/// Extractor that parses `multipart/form-data` requests commonly used with file uploads.
///
/// Fields are read one at a time with [`Multipart::next_field`].
///
/// # Example
///
/// ```rust,no_run
/// use axum::prelude::*;
///
/// async fn upload(mut multipart: extract::Multipart) {
///     while let Some(field) = multipart.next_field().await.unwrap() {
///         let name = field.name().unwrap().to_string();
///         let data = field.bytes().await.unwrap();
///
///         println!("Length of `{}` is {} bytes", name, data.len());
///     }
/// }
///
/// let app = route("/upload", post(upload));
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// For security reasons it's recommended to combine this with
/// [`ContentLengthLimit`](super::ContentLengthLimit) to limit the size of the request payload.
#[derive(Debug)]
pub struct Multipart {
// The `multer` parser this extractor delegates to.
inner: multer::Multipart<'static>,
}
#[async_trait]
impl<B> FromRequest<B> for Multipart
where
    B: http_body::Body<Data = Bytes> + Default + Unpin + Send + 'static,
    B::Error: Into<BoxError> + 'static,
{
    type Rejection = MultipartRejection;

    /// Builds a [`Multipart`] from the request's `Content-Type` boundary and body.
    ///
    /// # Errors
    ///
    /// - [`InvalidBoundary`] if no boundary can be parsed from the `Content-Type` header.
    /// - [`BodyAlreadyExtracted`] if the request body cannot be taken.
    async fn from_request(req: &mut http::Request<B>) -> Result<Self, Self::Rejection> {
        // Validate the header *before* taking the body: taking the body is destructive, so a
        // request that is not multipart must keep its body available for other extractors.
        let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
        let stream = BodyStream::from_request(req).await?;
        let multipart = multer::Multipart::new(stream, boundary);
        Ok(Self { inner: multipart })
    }
}
impl Multipart {
    /// Yields the next [`Field`] if available.
    ///
    /// Returns `Ok(None)` when the underlying parser reports no further field.
    ///
    /// # Errors
    ///
    /// Returns a [`MultipartError`] wrapping the error reported by `multer`.
    pub async fn next_field(&mut self) -> Result<Option<Field<'_>>, MultipartError> {
        let next = self.inner.next_field().await;
        let inner = match next.map_err(MultipartError::from_multer)? {
            Some(inner) => inner,
            None => return Ok(None),
        };

        // Hand out the exclusive borrow of `self` together with the field so that no second
        // field can be requested while this one is alive.
        Ok(Some(Field {
            inner,
            _multipart: self,
        }))
    }
}
/// A single field in a multipart stream.
///
/// Obtained from [`Multipart::next_field`]. The field's data can be read incrementally through
/// its [`Stream`] implementation or collected in one go with [`Field::bytes`] / [`Field::text`].
#[derive(Debug)]
pub struct Field<'a> {
// The wrapped `multer` field that all accessors delegate to.
inner: multer::Field<'static>,
// multer requires that only one `multer::Field` is live at any point. Holding the exclusive
// borrow of the parent `Multipart` enforces that statically; multer itself does not, and
// returns an error at runtime instead.
_multipart: &'a mut Multipart,
}
impl Stream for Field<'_> {
    type Item = Result<Bytes, MultipartError>;

    /// Polls the wrapped `multer` field for its next chunk, converting any error into a
    /// [`MultipartError`].
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.inner);
        inner.poll_next(cx).map_err(MultipartError::from_multer)
    }
}
impl<'a> Field<'a> {
/// The field name found in the
/// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
/// header.
///
/// Returns `None` if the underlying `multer` field reports no name.
pub fn name(&self) -> Option<&str> {
self.inner.name()
}
/// The file name found in the
/// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition)
/// header.
///
/// Returns `None` if the underlying `multer` field reports no file name.
pub fn file_name(&self) -> Option<&str> {
self.inner.file_name()
}
/// Get the content type of the field.
///
/// Returns `None` if the underlying `multer` field reports no content type.
pub fn content_type(&self) -> Option<&Mime> {
self.inner.content_type()
}
/// Get a map of headers as [`HeaderMap`].
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/// Get the full data of the field as [`Bytes`].
///
/// Consumes the field.
///
/// # Errors
///
/// Returns a [`MultipartError`] wrapping the error reported by `multer`.
pub async fn bytes(self) -> Result<Bytes, MultipartError> {
self.inner
.bytes()
.await
.map_err(MultipartError::from_multer)
}
/// Get the full field data as text.
///
/// Consumes the field.
///
/// # Errors
///
/// Returns a [`MultipartError`] wrapping the error reported by `multer`.
pub async fn text(self) -> Result<String, MultipartError> {
self.inner.text().await.map_err(MultipartError::from_multer)
}
}
/// Errors associated with parsing `multipart/form-data` requests.
///
/// Wraps the error reported by `multer`; the wrapped error is exposed through
/// [`std::error::Error::source`].
#[derive(Debug)]
pub struct MultipartError {
// The underlying `multer` error.
source: multer::Error,
}
impl MultipartError {
fn from_multer(multer: multer::Error) -> Self {
Self { source: multer }
}
}
impl fmt::Display for MultipartError {
    /// Writes a fixed, generic message; the specific cause is available via `source()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Error parsing `multipart/form-data` request")
    }
}
impl std::error::Error for MultipartError {
    /// Exposes the wrapped `multer` error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.source;
        Some(cause)
    }
}
/// Extracts the multipart boundary from the `Content-Type` header.
///
/// Returns `None` if the header is absent, is not valid as a string, or if `multer` cannot
/// parse a boundary out of it.
fn parse_boundary(headers: &HeaderMap) -> Option<String> {
    headers
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|content_type| multer::parse_boundary(content_type).ok())
}
composite_rejection! {
/// Rejection used for [`Multipart`].
///
/// Contains one variant for each way the [`Multipart`] extractor can fail.
pub enum MultipartRejection {
// Produced when the request body cannot be taken while building the body stream.
BodyAlreadyExtracted,
// Produced when no boundary can be parsed from the `Content-Type` header.
InvalidBoundary,
}
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Invalid `boundary` for `multipart/form-data` request"]
/// Rejection type used if the `boundary` of a `multipart/form-data` request is
/// missing or invalid.
///
/// The boundary is read from the request's `Content-Type` header.
pub struct InvalidBoundary;
}

View file

@ -4,57 +4,6 @@ use super::IntoResponse;
use crate::body::Body;
use tower::BoxError;
// Defines a rejection type together with its `IntoResponse` implementation.
//
// Two forms are accepted:
// - `pub struct Name;` generates a unit struct whose response has the given status and the
//   given body.
// - `pub struct Name (BoxError);` generates a tuple struct wrapping a `tower::BoxError`, plus a
//   `from_err` constructor; the response body is the given body followed by `": "` and the
//   wrapped error's `Display` output.
//
// Any additional attributes (e.g. doc comments) placed after `#[body = ...]` are forwarded to
// the generated struct.
macro_rules! define_rejection {
// Unit-struct form.
(
#[status = $status:ident]
#[body = $body:expr]
$(#[$m:meta])*
pub struct $name:ident;
) => {
$(#[$m])*
#[derive(Debug)]
#[non_exhaustive]
pub struct $name;
impl IntoResponse for $name {
fn into_response(self) -> http::Response<Body> {
let mut res = http::Response::new(Body::from($body));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
// Error-wrapping form.
(
#[status = $status:ident]
#[body = $body:expr]
$(#[$m:meta])*
pub struct $name:ident (BoxError);
) => {
$(#[$m])*
#[derive(Debug)]
pub struct $name(pub(super) tower::BoxError);
impl $name {
pub(super) fn from_err<E>(err: E) -> Self
where
E: Into<tower::BoxError>,
{
Self(err.into())
}
}
impl IntoResponse for $name {
fn into_response(self) -> http::Response<Body> {
let mut res =
http::Response::new(Body::from(format!(concat!($body, ": {}"), self.0)));
*res.status_mut() = http::StatusCode::$status;
res
}
}
};
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Query string was invalid or missing"]
@ -205,44 +154,6 @@ impl IntoResponse for FailedToDeserializeQueryString {
}
}
// Defines an enum that aggregates several existing rejection types.
//
// Each listed identifier becomes a tuple variant holding the rejection type of the same name.
// The macro also generates:
// - an `IntoResponse` impl that delegates to the wrapped rejection, and
// - a `From<Variant>` impl per variant, so `?` can convert each rejection into the enum.
macro_rules! composite_rejection {
(
$(#[$m:meta])*
pub enum $name:ident {
$($variant:ident),+
$(,)?
}
) => {
$(#[$m])*
#[derive(Debug)]
#[non_exhaustive]
pub enum $name {
$(
#[allow(missing_docs)]
$variant($variant)
),+
}
impl IntoResponse for $name {
fn into_response(self) -> http::Response<Body> {
match self {
$(
Self::$variant(inner) => inner.into_response(),
)+
}
}
}
$(
impl From<$variant> for $name {
fn from(inner: $variant) -> Self {
Self::$variant(inner)
}
}
)+
};
}
composite_rejection! {
/// Rejection used for [`Query`](super::Query).
///

View file

@ -546,6 +546,7 @@
//!
//! - `ws`: Enables WebSockets support.
//! - `headers`: Enables extracting typed headers via [`extract::TypedHeader`].
//! - `multipart`: Enables parsing `multipart/form-data` requests with [`extract::Multipart`].
//!
//! [tower]: https://crates.io/crates/tower
//! [tower-http]: https://crates.io/crates/tower-http

View file

@ -32,3 +32,92 @@ macro_rules! opaque_future {
}
};
}
// Defines a rejection type together with its `IntoResponse` implementation.
//
// Two forms are accepted:
// - `pub struct Name;` generates a unit struct whose response has the given status and the
//   given body.
// - `pub struct Name (BoxError);` generates a tuple struct wrapping a `tower::BoxError`, plus a
//   `from_err` constructor; the response body is the given body followed by `": "` and the
//   wrapped error's `Display` output.
//
// Any additional attributes (e.g. doc comments) placed after `#[body = ...]` are forwarded to
// the generated struct.
//
// Both arms name crate items through `$crate::` so the macro expands correctly in modules that
// do not import `IntoResponse` or `Body` themselves.
macro_rules! define_rejection {
    // Unit-struct form.
    (
        #[status = $status:ident]
        #[body = $body:expr]
        $(#[$m:meta])*
        pub struct $name:ident;
    ) => {
        $(#[$m])*
        #[derive(Debug)]
        #[non_exhaustive]
        pub struct $name;

        impl $crate::response::IntoResponse for $name {
            fn into_response(self) -> http::Response<$crate::body::Body> {
                let mut res = http::Response::new($crate::body::Body::from($body));
                *res.status_mut() = http::StatusCode::$status;
                res
            }
        }
    };

    // Error-wrapping form.
    (
        #[status = $status:ident]
        #[body = $body:expr]
        $(#[$m:meta])*
        pub struct $name:ident (BoxError);
    ) => {
        $(#[$m])*
        #[derive(Debug)]
        pub struct $name(pub(super) tower::BoxError);

        impl $name {
            pub(super) fn from_err<E>(err: E) -> Self
            where
                E: Into<tower::BoxError>,
            {
                Self(err.into())
            }
        }

        impl $crate::response::IntoResponse for $name {
            fn into_response(self) -> http::Response<$crate::body::Body> {
                let mut res = http::Response::new($crate::body::Body::from(format!(
                    concat!($body, ": {}"),
                    self.0
                )));
                *res.status_mut() = http::StatusCode::$status;
                res
            }
        }
    };
}
// Defines an enum that aggregates several existing rejection types.
//
// Each listed identifier becomes a tuple variant holding the rejection type of the same name.
// The macro also generates:
// - an `IntoResponse` impl that delegates to the wrapped rejection, and
// - a `From<Variant>` impl per variant, so `?` can convert each rejection into the enum.
macro_rules! composite_rejection {
    (
        $(#[$m:meta])*
        pub enum $name:ident {
            $($variant:ident),+
            $(,)?
        }
    ) => {
        $(#[$m])*
        #[derive(Debug)]
        #[non_exhaustive]
        pub enum $name {
            $(
                #[allow(missing_docs)]
                $variant($variant)
            ),+
        }

        impl $crate::response::IntoResponse for $name {
            fn into_response(self) -> http::Response<$crate::body::Body> {
                match self {
                    $(
                        // Fully-qualified call: method-call syntax would require the trait to
                        // be imported in every module that invokes this macro.
                        Self::$variant(inner) => {
                            $crate::response::IntoResponse::into_response(inner)
                        }
                    )+
                }
            }
        }

        $(
            impl From<$variant> for $name {
                fn from(inner: $variant) -> Self {
                    Self::$variant(inner)
                }
            }
        )+
    };
}