From e1eb7d66152e72d039c9d4fb7ee15b5196fd125d Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Sun, 27 Nov 2022 11:57:47 +0100
Subject: [PATCH] Add `axum_core::body::Body` (#1584)

---
 axum-core/src/body.rs | 94 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 92 insertions(+), 2 deletions(-)

diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs
index 9f254089..601e18a8 100644
--- a/axum-core/src/body.rs
+++ b/axum-core/src/body.rs
@@ -3,7 +3,11 @@
 use crate::{BoxError, Error};
 use bytes::Bytes;
 use bytes::{Buf, BufMut};
-use http_body::Body;
+use futures_util::stream::Stream;
+use http::HeaderMap;
+use http_body::Body as _;
+use std::pin::Pin;
+use std::task::{Context, Poll};
 
 /// A boxed [`Body`] trait object.
 ///
@@ -55,7 +59,7 @@ where
 // THE SOFTWARE.
 pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
 where
-    T: Body,
+    T: http_body::Body,
 {
     futures_util::pin_mut!(body);
 
@@ -85,6 +89,92 @@ where
     Ok(vec.into())
 }
 
+/// The body type used in axum requests and responses.
+#[derive(Debug)]
+pub struct Body(BoxBody);
+
+impl Body {
+    /// Create a new `Body` that wraps another [`http_body::Body`].
+    pub fn new<B>(body: B) -> Self
+    where
+        B: http_body::Body<Data = Bytes> + Send + 'static,
+        B::Error: Into<BoxError>,
+    {
+        try_downcast(body).unwrap_or_else(|body| Self(boxed(body)))
+    }
+
+    /// Create an empty body.
+    pub fn empty() -> Self {
+        Self::new(http_body::Empty::new())
+    }
+}
+
+impl Default for Body {
+    fn default() -> Self {
+        Self::empty()
+    }
+}
+
+macro_rules! body_from_impl {
+    ($ty:ty) => {
+        impl From<$ty> for Body {
+            fn from(buf: $ty) -> Self {
+                Self::new(http_body::Full::from(buf))
+            }
+        }
+    };
+}
+
+body_from_impl!(&'static [u8]);
+body_from_impl!(std::borrow::Cow<'static, [u8]>);
+body_from_impl!(Vec<u8>);
+
+body_from_impl!(&'static str);
+body_from_impl!(std::borrow::Cow<'static, str>);
+body_from_impl!(String);
+
+body_from_impl!(Bytes);
+
+impl http_body::Body for Body {
+    type Data = Bytes;
+    type Error = Error;
+
+    #[inline]
+    fn poll_data(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> {
+        Pin::new(&mut self.0).poll_data(cx)
+    }
+
+    #[inline]
+    fn poll_trailers(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> std::task::Poll<Result<Option<HeaderMap>, Self::Error>> {
+        Pin::new(&mut self.0).poll_trailers(cx)
+    }
+
+    #[inline]
+    fn size_hint(&self) -> http_body::SizeHint {
+        self.0.size_hint()
+    }
+
+    #[inline]
+    fn is_end_stream(&self) -> bool {
+        self.0.is_end_stream()
+    }
+}
+
+impl Stream for Body {
+    type Item = Result<Bytes, Error>;
+
+    #[inline]
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.poll_data(cx)
+    }
+}
+
 #[test]
 fn test_try_downcast() {
     assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));