From 2b360a7873f87f6b9bbd1c341ceae3117f969f7e Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Sun, 13 Jun 2021 12:06:59 +0200
Subject: [PATCH] Support getting error from extractors (#14)

Makes `Result<T, T::Rejection>` an extractor and makes all extraction errors enums so no type information is lost.
---
 src/extract/mod.rs       | 147 ++++++++++++++++++++++----------------
 src/extract/rejection.rs | 148 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 234 insertions(+), 61 deletions(-)

diff --git a/src/extract/mod.rs b/src/extract/mod.rs
index 58898807..49e7c25b 100644
--- a/src/extract/mod.rs
+++ b/src/extract/mod.rs
@@ -111,6 +111,41 @@
 //! # };
 //! ```
 //!
+//! Wrapping extractors in `Result` makes them optional and gives you the reason
+//! the extraction failed:
+//!
+//! ```rust,no_run
+//! use awebframework::{extract::{Json, rejection::JsonRejection}, prelude::*};
+//! use serde_json::Value;
+//!
+//! async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
+//!     match payload {
+//!         Ok(payload) => {
+//!             // We got a valid JSON payload
+//!         }
+//!         Err(JsonRejection::MissingJsonContentType(_)) => {
+//!             // Request didn't have `Content-Type: application/json`
+//!             // header
+//!         }
+//!         Err(JsonRejection::InvalidJsonBody(_)) => {
+//!             // Couldn't deserialize the body into the target type
+//!         }
+//!         Err(JsonRejection::BodyAlreadyExtracted(_)) => {
+//!             // Another extractor had already consumed the body
+//!         }
+//!         Err(_) => {
+//!             // `JsonRejection` is marked `#[non_exhaustive]` so match must
+//!             // include a catch-all case.
+//!         }
+//!     }
+//! }
+//!
+//! let app = route("/users", post(create_user));
+//! # async {
+//! # app.serve(&"".parse().unwrap()).await.unwrap();
+//! # };
+//! ```
+//!
 //! # Reducing boilerplate
 //!
 //! If you're feeling adventorous you can even deconstruct the extractors
@@ -133,13 +168,8 @@
 use crate::{body::Body, response::IntoResponse};
 use async_trait::async_trait;
 use bytes::{Buf, Bytes};
-use http::{header, HeaderMap, Method, Request, Response, Uri, Version};
-use rejection::{
-    BodyAlreadyExtracted, FailedToBufferBody, FailedToDeserializeQueryString,
-    InvalidFormContentType, InvalidJsonBody, InvalidUrlParam, InvalidUtf8, LengthRequired,
-    MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
-    QueryStringMissing, RequestAlreadyExtracted, UrlParamsAlreadyExtracted,
-};
+use http::{header, HeaderMap, Method, Request, Uri, Version};
+use rejection::*;
 use serde::de::DeserializeOwned;
 use std::{collections::HashMap, convert::Infallible, mem, str::FromStr};
 
@@ -170,6 +200,18 @@ where
     }
 }
 
+#[async_trait]
+impl<T> FromRequest for Result<T, T::Rejection>
+where
+    T: FromRequest,
+{
+    type Rejection = Infallible;
+
+    async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
+        Ok(T::from_request(req).await)
+    }
+}
+
 /// Extractor that deserializes query strings into some type.
 ///
 /// `T` is expected to implement [`serde::Deserialize`].
@@ -207,17 +249,12 @@ impl<T> FromRequest for Query<T>
 where
     T: DeserializeOwned,
 {
-    type Rejection = Response<Body>;
+    type Rejection = QueryRejection;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
-        let query = req
-            .uri()
-            .query()
-            .ok_or(QueryStringMissing)
-            .map_err(IntoResponse::into_response)?;
+        let query = req.uri().query().ok_or(QueryStringMissing)?;
         let value = serde_urlencoded::from_str(query)
-            .map_err(FailedToDeserializeQueryString::new::<T, _>)
-            .map_err(IntoResponse::into_response)?;
+            .map_err(FailedToDeserializeQueryString::new::<T, _>)?;
         Ok(Query(value))
     }
 }
@@ -257,33 +294,26 @@ impl<T> FromRequest for Form<T>
 where
     T: DeserializeOwned,
 {
-    type Rejection = Response<Body>;
+    type Rejection = FormRejection;
 
     #[allow(warnings)]
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
         if !has_content_type(&req, "application/x-www-form-urlencoded") {
-            return Err(InvalidFormContentType.into_response());
+            Err(InvalidFormContentType)?;
         }
 
         if req.method() == Method::GET {
-            let query = req
-                .uri()
-                .query()
-                .ok_or(QueryStringMissing)
-                .map_err(IntoResponse::into_response)?;
+            let query = req.uri().query().ok_or(QueryStringMissing)?;
             let value = serde_urlencoded::from_str(query)
-                .map_err(FailedToDeserializeQueryString::new::<T, _>)
-                .map_err(IntoResponse::into_response)?;
+                .map_err(FailedToDeserializeQueryString::new::<T, _>)?;
             Ok(Form(value))
         } else {
-            let body = take_body(req).map_err(IntoResponse::into_response)?;
+            let body = take_body(req)?;
             let chunks = hyper::body::aggregate(body)
                 .await
-                .map_err(FailedToBufferBody::from_err)
-                .map_err(IntoResponse::into_response)?;
+                .map_err(FailedToBufferBody::from_err)?;
             let value = serde_urlencoded::from_reader(chunks.reader())
-                .map_err(FailedToDeserializeQueryString::new::<T, _>)
-                .map_err(IntoResponse::into_response)?;
+                .map_err(FailedToDeserializeQueryString::new::<T, _>)?;
 
             Ok(Form(value))
         }
@@ -327,26 +357,23 @@ impl<T> FromRequest for Json<T>
 where
     T: DeserializeOwned,
 {
-    type Rejection = Response<Body>;
+    type Rejection = JsonRejection;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
         use bytes::Buf;
 
         if has_content_type(req, "application/json") {
-            let body = take_body(req).map_err(IntoResponse::into_response)?;
+            let body = take_body(req)?;
 
             let buf = hyper::body::aggregate(body)
                 .await
-                .map_err(InvalidJsonBody::from_err)
-                .map_err(IntoResponse::into_response)?;
+                .map_err(InvalidJsonBody::from_err)?;
 
-            let value = serde_json::from_reader(buf.reader())
-                .map_err(InvalidJsonBody::from_err)
-                .map_err(IntoResponse::into_response)?;
+            let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?;
 
             Ok(Json(value))
         } else {
-            Err(MissingJsonContentType.into_response())
+            Err(MissingJsonContentType.into())
         }
     }
 }
@@ -419,15 +446,14 @@ where
 
 #[async_trait]
 impl FromRequest for Bytes {
-    type Rejection = Response<Body>;
+    type Rejection = BytesRejection;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
-        let body = take_body(req).map_err(IntoResponse::into_response)?;
+        let body = take_body(req)?;
 
         let bytes = hyper::body::to_bytes(body)
             .await
-            .map_err(FailedToBufferBody::from_err)
-            .map_err(IntoResponse::into_response)?;
+            .map_err(FailedToBufferBody::from_err)?;
 
         Ok(bytes)
     }
@@ -435,20 +461,17 @@ impl FromRequest for Bytes {
 
 #[async_trait]
 impl FromRequest for String {
-    type Rejection = Response<Body>;
+    type Rejection = StringRejection;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
-        let body = take_body(req).map_err(IntoResponse::into_response)?;
+        let body = take_body(req)?;
 
         let bytes = hyper::body::to_bytes(body)
             .await
-            .map_err(FailedToBufferBody::from_err)
-            .map_err(IntoResponse::into_response)?
+            .map_err(FailedToBufferBody::from_err)?
             .to_vec();
 
-        let string = String::from_utf8(bytes)
-            .map_err(InvalidUtf8::from_err)
-            .map_err(IntoResponse::into_response)?;
+        let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?;
 
         Ok(string)
     }
@@ -541,7 +564,7 @@ impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
 where
     T: FromRequest,
 {
-    type Rejection = Response<Body>;
+    type Rejection = ContentLengthLimitRejection<T::Rejection>;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
         let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
@@ -551,15 +574,17 @@ where
 
         if let Some(length) = content_length {
             if length > N {
-                return Err(PayloadTooLarge.into_response());
+                return Err(ContentLengthLimitRejection::PayloadTooLarge(
+                    PayloadTooLarge,
+                ));
             }
         } else {
-            return Err(LengthRequired.into_response());
+            return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
         };
 
         let value = T::from_request(req)
             .await
-            .map_err(IntoResponse::into_response)?;
+            .map_err(ContentLengthLimitRejection::Inner)?;
 
         Ok(Self(value))
     }
@@ -603,7 +628,7 @@ impl UrlParamsMap {
 
 #[async_trait]
 impl FromRequest for UrlParamsMap {
-    type Rejection = Response<Body>;
+    type Rejection = UrlParamsMapRejection;
 
     async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
         if let Some(params) = req
@@ -613,10 +638,10 @@ impl FromRequest for UrlParamsMap {
             if let Some(params) = params.take() {
                 Ok(Self(params.0.into_iter().collect()))
             } else {
-                Err(UrlParamsAlreadyExtracted.into_response())
+                Err(UrlParamsAlreadyExtracted.into())
             }
         } else {
-            Err(MissingRouteParams.into_response())
+            Err(MissingRouteParams.into())
         }
     }
 }
@@ -656,7 +681,7 @@ macro_rules! impl_parse_url {
             $head: FromStr + Send,
             $( $tail: FromStr + Send, )*
         {
-            type Rejection = Response<Body>;
+            type Rejection = UrlParamsRejection;
 
             #[allow(non_snake_case)]
             async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
@@ -667,30 +692,30 @@ macro_rules! impl_parse_url {
                     if let Some(params) = params.take() {
                         params.0
                     } else {
-                        return Err(UrlParamsAlreadyExtracted.into_response());
+                        return Err(UrlParamsAlreadyExtracted.into());
                     }
                 } else {
-                    return Err(MissingRouteParams.into_response())
+                    return Err(MissingRouteParams.into())
                 };
 
                 if let [(_, $head), $((_, $tail),)*] = &*params {
                     let $head = if let Ok(x) = $head.parse::<$head>() {
                        x
                     } else {
-                        return Err(InvalidUrlParam::new::<$head>().into_response());
+                        return Err(InvalidUrlParam::new::<$head>().into());
                     };
 
                     $(
                         let $tail = if let Ok(x) = $tail.parse::<$tail>() {
                            x
                         } else {
-                            return Err(InvalidUrlParam::new::<$tail>().into_response());
+                            return Err(InvalidUrlParam::new::<$tail>().into());
                         };
                     )*
 
                     Ok(UrlParams(($head, $($tail,)*)))
                 } else {
-                    return Err(MissingRouteParams.into_response())
+                    Err(MissingRouteParams.into())
                 }
             }
         }
diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs
index 51c481cc..8f27efee 100644
--- a/src/extract/rejection.rs
+++ b/src/extract/rejection.rs
@@ -212,3 +212,151 @@ impl IntoResponse for FailedToDeserializeQueryString {
         res
     }
 }
+
+macro_rules! composite_rejection {
+    (
+        $(#[$m:meta])*
+        pub enum $name:ident {
+            $($variant:ident),+
+            $(,)?
+        }
+    ) => {
+        $(#[$m])*
+        #[derive(Debug)]
+        #[non_exhaustive]
+        pub enum $name {
+            $(
+                #[allow(missing_docs)]
+                $variant($variant)
+            ),+
+        }
+
+        impl IntoResponse for $name {
+            fn into_response(self) -> http::Response<Body> {
+                match self {
+                    $(
+                        Self::$variant(inner) => inner.into_response(),
+                    )+
+                }
+            }
+        }
+
+        $(
+            impl From<$variant> for $name {
+                fn from(inner: $variant) -> Self {
+                    Self::$variant(inner)
+                }
+            }
+        )+
+    };
+}
+
+composite_rejection! {
+    /// Rejection used for [`Query`](super::Query).
+    ///
+    /// Contains one variant for each way the [`Query`](super::Query) extractor
+    /// can fail.
+    pub enum QueryRejection {
+        QueryStringMissing,
+        FailedToDeserializeQueryString,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`Form`](super::Form).
+    ///
+    /// Contains one variant for each way the [`Form`](super::Form) extractor
+    /// can fail.
+    pub enum FormRejection {
+        InvalidFormContentType,
+        QueryStringMissing,
+        FailedToDeserializeQueryString,
+        FailedToBufferBody,
+        BodyAlreadyExtracted,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`Json`](super::Json).
+    ///
+    /// Contains one variant for each way the [`Json`](super::Json) extractor
+    /// can fail.
+    pub enum JsonRejection {
+        InvalidJsonBody,
+        MissingJsonContentType,
+        BodyAlreadyExtracted,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`UrlParamsMap`](super::UrlParamsMap).
+    ///
+    /// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor
+    /// can fail.
+    pub enum UrlParamsMapRejection {
+        UrlParamsAlreadyExtracted,
+        MissingRouteParams,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`UrlParams`](super::UrlParams).
+    ///
+    /// Contains one variant for each way the [`UrlParams`](super::UrlParams) extractor
+    /// can fail.
+    pub enum UrlParamsRejection {
+        InvalidUrlParam,
+        UrlParamsAlreadyExtracted,
+        MissingRouteParams,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`Bytes`](bytes::Bytes).
+    ///
+    /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor
+    /// can fail.
+    pub enum BytesRejection {
+        BodyAlreadyExtracted,
+        FailedToBufferBody,
+    }
+}
+
+composite_rejection! {
+    /// Rejection used for [`String`].
+    ///
+    /// Contains one variant for each way the [`String`] extractor can fail.
+    pub enum StringRejection {
+        BodyAlreadyExtracted,
+        FailedToBufferBody,
+        InvalidUtf8,
+    }
+}
+
+/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
+///
+/// Contains one variant for each way the
+/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail.
+#[derive(Debug)]
+#[non_exhaustive]
+pub enum ContentLengthLimitRejection<T> {
+    #[allow(missing_docs)]
+    PayloadTooLarge(PayloadTooLarge),
+    #[allow(missing_docs)]
+    LengthRequired(LengthRequired),
+    #[allow(missing_docs)]
+    Inner(T),
+}
+
+impl<T> IntoResponse for ContentLengthLimitRejection<T>
+where
+    T: IntoResponse,
+{
+    fn into_response(self) -> http::Response<Body> {
+        match self {
+            Self::PayloadTooLarge(inner) => inner.into_response(),
+            Self::LengthRequired(inner) => inner.into_response(),
+            Self::Inner(inner) => inner.into_response(),
+        }
+    }
+}