diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index ed745ec8..9eb2affa 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning]. # Unreleased -- None. +- **added:** Add `Multipart`. This is similar to `axum::extract::Multipart` + except that it enforces field exclusivity at runtime instead of compile time, + as this improves usability. # 0.6.0 (24. February, 2022) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 87325b7c..dea1767b 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -29,6 +29,7 @@ json-lines = [ "tokio-util?/io", "tokio-stream?/io-util" ] +multipart = ["dep:multer"] protobuf = ["dep:prost"] query = ["dep:serde", "dep:serde_html_form"] spa = ["tower-http/fs"] @@ -51,6 +52,7 @@ tower-service = "0.3" axum-macros = { path = "../axum-macros", version = "0.3.4", optional = true } cookie = { package = "cookie", version = "0.17", features = ["percent-encode"], optional = true } form_urlencoded = { version = "1.1.0", optional = true } +multer = { version = "2.0.0", optional = true } percent-encoding = { version = "2.1", optional = true } prost = { version = "0.11", optional = true } serde = { version = "1.0", optional = true } @@ -62,6 +64,7 @@ tokio-util = { version = "0.7", optional = true } [dev-dependencies] axum = { path = "../axum", version = "0.6.0", features = ["headers"] } futures = "0.3" +http-body = "0.4.4" hyper = "0.14" reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index eaed0a4d..3ca7749e 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -11,6 +11,9 @@ pub mod cookie; #[cfg(feature = "query")] mod query; +#[cfg(feature = "multipart")] +pub mod multipart; + mod with_rejection; pub use self::cached::Cached; @@ -30,6 +33,9 @@ pub use self::form::{Form, FormRejection}; #[cfg(feature = "query")] pub use self::query::{Query, QueryRejection}; +#[cfg(feature = "multipart")] +pub use self::multipart::Multipart; + #[cfg(feature = "json-lines")] #[doc(no_inline)] pub use crate::json_lines::JsonLines; diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs new file mode 100644 index 00000000..ff74e30e --- /dev/null +++ b/axum-extra/src/extract/multipart.rs @@ -0,0 +1,398 @@ +//! Extractor that parses `multipart/form-data` requests commonly used with file uploads. +//! +//! See [`Multipart`] for more details. + +use axum::{ + async_trait, + body::{Bytes, HttpBody}, + extract::{BodyStream, FromRequest}, + response::{IntoResponse, Response}, + BoxError, RequestExt, +}; +use futures_util::stream::Stream; +use http::{ + header::{HeaderMap, CONTENT_TYPE}, + Request, +}; +use std::{ + fmt, + pin::Pin, + task::{Context, Poll}, +}; + +/// Extractor that parses `multipart/form-data` requests (commonly used with file uploads). +/// +/// Since extracting multipart form data from the request requires consuming the body, the +/// `Multipart` extractor must be *last* if there are multiple extractors in a handler. +/// See ["the order of extractors"][order-of-extractors] +/// +/// [order-of-extractors]: crate::extract#the-order-of-extractors +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// routing::post, +/// Router, +/// }; +/// use axum_extra::extract::Multipart; +/// +/// async fn upload(mut multipart: Multipart) { +/// while let Some(mut field) = multipart.next_field().await.unwrap() { +/// let name = field.name().unwrap().to_string(); +/// let data = field.bytes().await.unwrap(); +/// +/// println!("Length of `{}` is {} bytes", name, data.len()); +/// } +/// } +/// +/// let app = Router::new().route("/upload", post(upload)); +/// # let _: Router = app; +/// ``` +/// +/// # Field Exclusivity +/// +/// A [`Field`] represents a raw, self-decoding stream into multipart data. As such, only one +/// [`Field`] from a given Multipart instance may be live at once. That is, a [`Field`] emitted by +/// [`next_field()`] must be dropped before calling [`next_field()`] again. Failure to do so will +/// result in an error. +/// +/// ``` +/// use axum_extra::extract::Multipart; +/// +/// async fn handler(mut multipart: Multipart) { +/// let field_1 = multipart.next_field().await; +/// +/// // We cannot get the next field while `field_1` is still alive. Have to drop `field_1` +/// // first. +/// let field_2 = multipart.next_field().await; +/// assert!(field_2.is_err()); +/// } +/// ``` +/// +/// In general you should consume `Multipart` by looping over the fields in order and make sure not +/// to keep `Field`s around from previous loop iterations. That will mimimize the risk of runtime +/// errors. +/// +/// # Differences between this and `axum::extract::Multipart` +/// +/// `axum::extract::Multipart` uses lifetimes to enforce field exclusivity at compile time, however +/// that leads to significant usability issues such as `Field` not being `'static`. +/// +/// `axum_extra::extract::Multipart` instead enforces field exclusivity at runtime which makes +/// things easier to use at the cost of possible runtime errors. +/// +/// [`next_field()`]: Multipart::next_field +#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] +#[derive(Debug)] +pub struct Multipart { + inner: multer::Multipart<'static>, +} + +#[async_trait] +impl FromRequest for Multipart +where + B: HttpBody + Send + 'static, + B::Data: Into, + B::Error: Into, + S: Send + Sync, +{ + type Rejection = MultipartRejection; + + async fn from_request(req: Request, state: &S) -> Result { + let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; + let stream_result = match req.with_limited_body() { + Ok(limited) => BodyStream::from_request(limited, state).await, + Err(unlimited) => BodyStream::from_request(unlimited, state).await, + }; + let stream = stream_result.unwrap_or_else(|err| match err {}); + let multipart = multer::Multipart::new(stream, boundary); + Ok(Self { inner: multipart }) + } +} + +impl Multipart { + /// Yields the next [`Field`] if available. + pub async fn next_field(&mut self) -> Result, MultipartError> { + let field = self + .inner + .next_field() + .await + .map_err(MultipartError::from_multer)?; + + if let Some(field) = field { + Ok(Some(Field { inner: field })) + } else { + Ok(None) + } + } + + /// Convert the `Multipart` into a stream of its fields. + pub fn into_stream(self) -> impl Stream> + Send + 'static { + futures_util::stream::try_unfold(self, |mut multipart| async move { + let field = multipart.next_field().await?; + Ok(field.map(|field| (field, multipart))) + }) + } +} + +/// A single field in a multipart stream. +#[derive(Debug)] +pub struct Field { + inner: multer::Field<'static>, +} + +impl Stream for Field { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner) + .poll_next(cx) + .map_err(MultipartError::from_multer) + } +} + +impl Field { + /// The field name found in the + /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) + /// header. + pub fn name(&self) -> Option<&str> { + self.inner.name() + } + + /// The file name found in the + /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) + /// header. + pub fn file_name(&self) -> Option<&str> { + self.inner.file_name() + } + + /// Get the [content type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) of the field. + pub fn content_type(&self) -> Option<&str> { + self.inner.content_type().map(|m| m.as_ref()) + } + + /// Get a map of headers as [`HeaderMap`]. + pub fn headers(&self) -> &HeaderMap { + self.inner.headers() + } + + /// Get the full data of the field as [`Bytes`]. + pub async fn bytes(self) -> Result { + self.inner + .bytes() + .await + .map_err(MultipartError::from_multer) + } + + /// Get the full field data as text. + pub async fn text(self) -> Result { + self.inner.text().await.map_err(MultipartError::from_multer) + } + + /// Stream a chunk of the field data. + /// + /// When the field data has been exhausted, this will return [`None`]. + /// + /// Note this does the same thing as `Field`'s [`Stream`] implementation. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// routing::post, + /// response::IntoResponse, + /// http::StatusCode, + /// Router, + /// }; + /// use axum_extra::extract::Multipart; + /// + /// async fn upload(mut multipart: Multipart) -> Result<(), (StatusCode, String)> { + /// while let Some(mut field) = multipart + /// .next_field() + /// .await + /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))? + /// { + /// while let Some(chunk) = field + /// .chunk() + /// .await + /// .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))? + /// { + /// println!("received {} bytes", chunk.len()); + /// } + /// } + /// + /// Ok(()) + /// } + /// + /// let app = Router::new().route("/upload", post(upload)); + /// # let _: Router = app; + /// ``` + pub async fn chunk(&mut self) -> Result, MultipartError> { + self.inner + .chunk() + .await + .map_err(MultipartError::from_multer) + } +} + +/// Errors associated with parsing `multipart/form-data` requests. +#[derive(Debug)] +pub struct MultipartError { + source: multer::Error, +} + +impl MultipartError { + fn from_multer(multer: multer::Error) -> Self { + Self { source: multer } + } +} + +impl fmt::Display for MultipartError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Error parsing `multipart/form-data` request") + } +} + +impl std::error::Error for MultipartError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.source) + } +} + +fn parse_boundary(headers: &HeaderMap) -> Option { + let content_type = headers.get(CONTENT_TYPE)?.to_str().ok()?; + multer::parse_boundary(content_type).ok() +} + +/// Rejection used for [`Multipart`]. +/// +/// Contains one variant for each way the [`Multipart`] extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +pub enum MultipartRejection { + #[allow(missing_docs)] + InvalidBoundary(InvalidBoundary), +} + +impl IntoResponse for MultipartRejection { + fn into_response(self) -> Response { + match self { + Self::InvalidBoundary(inner) => inner.into_response(), + } + } +} + +impl MultipartRejection { + /// Get the response body text used for this rejection. + pub fn body_text(&self) -> String { + match self { + Self::InvalidBoundary(inner) => inner.body_text(), + } + } + + /// Get the status code used for this rejection. + pub fn status(&self) -> http::StatusCode { + match self { + Self::InvalidBoundary(inner) => inner.status(), + } + } +} + +impl From for MultipartRejection { + fn from(inner: InvalidBoundary) -> Self { + Self::InvalidBoundary(inner) + } +} + +impl std::fmt::Display for MultipartRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidBoundary(inner) => write!(f, "{}", inner.body_text()), + } + } +} + +impl std::error::Error for MultipartRejection { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidBoundary(inner) => Some(inner), + } + } +} + +/// Rejection type used if the `boundary` in a `multipart/form-data` is +/// missing or invalid. +#[derive(Debug, Default)] +#[non_exhaustive] +pub struct InvalidBoundary; + +impl IntoResponse for InvalidBoundary { + fn into_response(self) -> Response { + (self.status(), self.body_text()).into_response() + } +} + +impl InvalidBoundary { + /// Get the response body text used for this rejection. + pub fn body_text(&self) -> String { + "Invalid `boundary` for `multipart/form-data` request".into() + } + + /// Get the status code used for this rejection. + pub fn status(&self) -> http::StatusCode { + http::StatusCode::BAD_REQUEST + } +} + +impl std::fmt::Display for InvalidBoundary { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.body_text()) + } +} + +impl std::error::Error for InvalidBoundary {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::*; + use axum::{body::Body, response::IntoResponse, routing::post, Router}; + + #[tokio::test] + async fn content_type_with_encoding() { + const BYTES: &[u8] = "🦀".as_bytes(); + const FILE_NAME: &str = "index.html"; + const CONTENT_TYPE: &str = "text/html; charset=utf-8"; + + async fn handle(mut multipart: Multipart) -> impl IntoResponse { + let field = multipart.next_field().await.unwrap().unwrap(); + + assert_eq!(field.file_name().unwrap(), FILE_NAME); + assert_eq!(field.content_type().unwrap(), CONTENT_TYPE); + assert_eq!(field.bytes().await.unwrap(), BYTES); + + assert!(multipart.next_field().await.unwrap().is_none()); + } + + let app = Router::new().route("/", post(handle)); + + let client = TestClient::new(app); + + let form = reqwest::multipart::Form::new().part( + "file", + reqwest::multipart::Part::bytes(BYTES) + .file_name(FILE_NAME) + .mime_str(CONTENT_TYPE) + .unwrap(), + ); + + client.post("/").multipart(form).send().await; + } + + // No need for this to be a #[test], we just want to make sure it compiles + fn _multipart_from_request_limited() { + async fn handler(_: Multipart) {} + let _app: Router<(), http_body::Limited> = Router::new().route("/", post(handler)); + } +} diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index b8152c59..8cfcae9d 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -17,6 +17,7 @@ //! `erased-json` | Enables the `ErasedJson` response | No //! `form` | Enables the `Form` extractor | No //! `json-lines` | Enables the `JsonLines` extractor and response | No +//! `multipart` | Enables the `Multpart` extractor | No //! `protobuf` | Enables the `Protobuf` extractor and response | No //! `query` | Enables the `Query` extractor | No //! `spa` | Enables the `Spa` router | No