diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs index 9f254089..601e18a8 100644 --- a/axum-core/src/body.rs +++ b/axum-core/src/body.rs @@ -3,7 +3,11 @@ use crate::{BoxError, Error}; use bytes::Bytes; use bytes::{Buf, BufMut}; -use http_body::Body; +use futures_util::stream::Stream; +use http::HeaderMap; +use http_body::Body as _; +use std::pin::Pin; +use std::task::{Context, Poll}; /// A boxed [`Body`] trait object. /// @@ -55,7 +59,7 @@ where // THE SOFTWARE. pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> where - T: Body, + T: http_body::Body, { futures_util::pin_mut!(body); @@ -85,6 +89,92 @@ where Ok(vec.into()) } +/// The body type used in axum requests and responses. +#[derive(Debug)] +pub struct Body(BoxBody); + +impl Body { + /// Create a new `Body` that wraps another [`http_body::Body`]. + pub fn new<B>(body: B) -> Self + where + B: http_body::Body<Data = Bytes> + Send + 'static, + B::Error: Into<BoxError>, + { + try_downcast(body).unwrap_or_else(|body| Self(boxed(body))) + } + + /// Create an empty body. + pub fn empty() -> Self { + Self::new(http_body::Empty::new()) + } +} + +impl Default for Body { + fn default() -> Self { + Self::empty() + } +} + +macro_rules! body_from_impl { + ($ty:ty) => { + impl From<$ty> for Body { + fn from(buf: $ty) -> Self { + Self::new(http_body::Full::from(buf)) + } + } + }; +} + +body_from_impl!(&'static [u8]); +body_from_impl!(std::borrow::Cow<'static, [u8]>); +body_from_impl!(Vec<u8>); + +body_from_impl!(&'static str); +body_from_impl!(std::borrow::Cow<'static, str>); +body_from_impl!(String); + +body_from_impl!(Bytes); + +impl http_body::Body for Body { + type Data = Bytes; + type Error = Error; + + #[inline] + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> { + Pin::new(&mut self.0).poll_data(cx) + } + + #[inline] + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll<Result<Option<HeaderMap>, Self::Error>> { + Pin::new(&mut self.0).poll_trailers(cx) + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.0.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.0.is_end_stream() + } +} + +impl Stream for Body { + type Item = Result<Bytes, Error>; + + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.poll_data(cx) + } +} + #[test] fn test_try_downcast() { assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));