From 72c1b7a80c0e19320506a28e14d716507a52ffce Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 14 Mar 2023 09:13:19 +0100 Subject: [PATCH] Add `Body::from_stream` (#1848) --- axum-core/Cargo.toml | 2 + axum-core/src/body.rs | 61 ++++++++++++++++++++++++++ axum-extra/src/body/async_read_body.rs | 30 +++++-------- axum-extra/src/extract/multipart.rs | 6 +-- axum-extra/src/json_lines.rs | 4 +- axum/src/body/mod.rs | 4 -- axum/src/test_helpers/mod.rs | 1 - 7 files changed, 79 insertions(+), 29 deletions(-) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index 4373dadc..7d7bb37e 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -24,6 +24,8 @@ futures-util = { version = "0.3", default-features = false, features = ["alloc"] http = "0.2.7" http-body = "0.4.5" mime = "0.3.16" +pin-project-lite = "0.2.7" +sync_wrapper = "0.1.1" tower-layer = "0.3" tower-service = "0.3" diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs index 601e18a8..f3c16585 100644 --- a/axum-core/src/body.rs +++ b/axum-core/src/body.rs @@ -1,13 +1,17 @@ //! HTTP body utilities. +use crate::response::{IntoResponse, Response}; use crate::{BoxError, Error}; use bytes::Bytes; use bytes::{Buf, BufMut}; use futures_util::stream::Stream; +use futures_util::TryStream; use http::HeaderMap; use http_body::Body as _; +use pin_project_lite::pin_project; use std::pin::Pin; use std::task::{Context, Poll}; +use sync_wrapper::SyncWrapper; /// A boxed [`Body`] trait object. /// @@ -107,6 +111,20 @@ impl Body { pub fn empty() -> Self { Self::new(http_body::Empty::new()) } + + /// Create a new `Body` from a [`Stream`]. + /// + /// [`Stream`]: futures_util::stream::Stream + pub fn from_stream(stream: S) -> Self + where + S: TryStream + Send + 'static, + S::Ok: Into, + S::Error: Into, + { + Self::new(StreamBody { + stream: SyncWrapper::new(stream), + }) + } } impl Default for Body { @@ -175,6 +193,49 @@ impl Stream for Body { } } +pin_project! { + struct StreamBody { + #[pin] + stream: SyncWrapper, + } +} + +impl http_body::Body for StreamBody +where + S: TryStream, + S::Ok: Into, + S::Error: Into, +{ + type Data = Bytes; + type Error = Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let stream = self.project().stream.get_pin_mut(); + match futures_util::ready!(stream.try_poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))), + Some(Err(err)) => Poll::Ready(Some(Err(Error::new(err)))), + None => Poll::Ready(None), + } + } + + #[inline] + fn poll_trailers( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } +} + +impl IntoResponse for Body { + fn into_response(self) -> Response { + Response::new(self.0) + } +} + #[test] fn test_try_downcast() { assert_eq!(try_downcast::(5_u32), Err(5_u32)); diff --git a/axum-extra/src/body/async_read_body.rs b/axum-extra/src/body/async_read_body.rs index 5ea0fc59..ce87e436 100644 --- a/axum-extra/src/body/async_read_body.rs +++ b/axum-extra/src/body/async_read_body.rs @@ -1,5 +1,5 @@ use axum::{ - body::{self, Bytes, HttpBody, StreamBody}, + body::{Body, Bytes, HttpBody}, http::HeaderMap, response::{IntoResponse, Response}, Error, @@ -47,28 +47,25 @@ pin_project! { #[cfg(feature = "async-read-body")] #[derive(Debug)] #[must_use] - pub struct AsyncReadBody { + pub struct AsyncReadBody { #[pin] - read: StreamBody>, + body: Body, } } -impl AsyncReadBody { +impl AsyncReadBody { /// Create a new `AsyncReadBody`. - pub fn new(read: R) -> Self + pub fn new(read: R) -> Self where R: AsyncRead + Send + 'static, { Self { - read: StreamBody::new(ReaderStream::new(read)), + body: Body::from_stream(ReaderStream::new(read)), } } } -impl HttpBody for AsyncReadBody -where - R: AsyncRead + Send + 'static, -{ +impl HttpBody for AsyncReadBody { type Data = Bytes; type Error = Error; @@ -76,22 +73,19 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.project().read.poll_data(cx) + self.project().body.poll_data(cx) } fn poll_trailers( self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + self.project().body.poll_trailers(cx) } } -impl IntoResponse for AsyncReadBody -where - R: AsyncRead + Send + 'static, -{ +impl IntoResponse for AsyncReadBody { fn into_response(self) -> Response { - Response::new(body::boxed(self)) + self.body.into_response() } } diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 67aac918..b62eb923 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -7,7 +7,7 @@ use axum::{ body::{Body, Bytes}, extract::FromRequest, response::{IntoResponse, Response}, - BoxError, RequestExt, + RequestExt, }; use futures_util::stream::Stream; use http::{ @@ -410,9 +410,7 @@ impl std::error::Error for InvalidBoundary {} mod tests { use super::*; use crate::test_helpers::*; - use axum::{ - body::Body, extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router, - }; + use axum::{extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router}; #[tokio::test] async fn content_type_with_encoding() { diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 912529dd..973e418e 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -2,7 +2,7 @@ use axum::{ async_trait, - body::{Body, StreamBody}, + body::Body, extract::FromRequest, response::{IntoResponse, Response}, BoxError, @@ -166,7 +166,7 @@ where buf.write_all(b"\n")?; Ok::<_, BoxError>(buf.into_inner().freeze()) }); - let stream = StreamBody::new(stream); + let stream = Body::from_stream(stream); // there is no consensus around mime type yet // https://github.com/wardi/jsonlines/issues/36 diff --git a/axum/src/body/mod.rs b/axum/src/body/mod.rs index 9e1e826f..a1243ffd 100644 --- a/axum/src/body/mod.rs +++ b/axum/src/body/mod.rs @@ -1,9 +1,5 @@ //! HTTP body utilities. -mod stream_body; - -pub use self::stream_body::StreamBody; - #[doc(no_inline)] pub use http_body::{Body as HttpBody, Empty, Full}; diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index de455490..a8b4cf99 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -9,6 +9,5 @@ pub(crate) mod tracing_helpers; pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} -pub(crate) fn assert_unpin() {} pub(crate) struct NotSendSync(*const ());