From 2b360a7873f87f6b9bbd1c341ceae3117f969f7e Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Sun, 13 Jun 2021 12:06:59 +0200 Subject: [PATCH] Support getting error from extractors (#14) Makes `Result<T, T::Rejection>` an extractor and makes all extraction errors enums so no type information is lost. --- src/extract/mod.rs | 147 ++++++++++++++++++++++---------------- src/extract/rejection.rs | 148 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 61 deletions(-) diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 58898807..49e7c25b 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -111,6 +111,41 @@ //! # }; //! ``` //! +//! Wrapping extractors in `Result` makes them optional and gives you the reason +//! the extraction failed: +//! +//! ```rust,no_run +//! use awebframework::{extract::{Json, rejection::JsonRejection}, prelude::*}; +//! use serde_json::Value; +//! +//! async fn create_user(payload: Result<Json<Value>, JsonRejection>) { +//! match payload { +//! Ok(payload) => { +//! // We got a valid JSON payload +//! } +//! Err(JsonRejection::MissingJsonContentType(_)) => { +//! // Request didn't have `Content-Type: application/json` +//! // header +//! } +//! Err(JsonRejection::InvalidJsonBody(_)) => { +//! // Couldn't deserialize the body into the target type +//! } +//! Err(JsonRejection::BodyAlreadyExtracted(_)) => { +//! // Another extractor had already consumed the body +//! } +//! Err(_) => { +//! // `JsonRejection` is marked `#[non_exhaustive]` so match must +//! // include a catch-all case. +//! } +//! } +//! } +//! +//! let app = route("/users", post(create_user)); +//! # async { +//! # app.serve(&"".parse().unwrap()).await.unwrap(); +//! # }; +//! ``` +//! //! # Reducing boilerplate //! //! If you're feeling adventorous you can even deconstruct the extractors @@ -133,13 +168,8 @@ use crate::{body::Body, response::IntoResponse}; use async_trait::async_trait; use bytes::{Buf, Bytes}; -use http::{header, HeaderMap, Method, Request, Response, Uri, Version}; -use rejection::{ - BodyAlreadyExtracted, FailedToBufferBody, FailedToDeserializeQueryString, - InvalidFormContentType, InvalidJsonBody, InvalidUrlParam, InvalidUtf8, LengthRequired, - MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge, - QueryStringMissing, RequestAlreadyExtracted, UrlParamsAlreadyExtracted, -}; +use http::{header, HeaderMap, Method, Request, Uri, Version}; +use rejection::*; use serde::de::DeserializeOwned; use std::{collections::HashMap, convert::Infallible, mem, str::FromStr}; @@ -170,6 +200,18 @@ where } } +#[async_trait] +impl<T> FromRequest for Result<T, T::Rejection> +where + T: FromRequest, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { + Ok(T::from_request(req).await) + } +} + /// Extractor that deserializes query strings into some type. /// /// `T` is expected to implement [`serde::Deserialize`]. @@ -207,17 +249,12 @@ impl<T> FromRequest for Query<T> where T: DeserializeOwned, { - type Rejection = Response<Body>; + type Rejection = QueryRejection; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { - let query = req - .uri() - .query() - .ok_or(QueryStringMissing) - .map_err(IntoResponse::into_response)?; + let query = req.uri().query().ok_or(QueryStringMissing)?; let value = serde_urlencoded::from_str(query) - .map_err(FailedToDeserializeQueryString::new::<T, _>) - .map_err(IntoResponse::into_response)?; + .map_err(FailedToDeserializeQueryString::new::<T, _>)?; Ok(Query(value)) } } @@ -257,33 +294,26 @@ impl<T> FromRequest for Form<T> where T: DeserializeOwned, { - type Rejection = Response<Body>; + type Rejection = FormRejection; #[allow(warnings)] async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { if !has_content_type(&req, "application/x-www-form-urlencoded") { - return Err(InvalidFormContentType.into_response()); + Err(InvalidFormContentType)?; } if req.method() == Method::GET { - let query = req - .uri() - .query() - .ok_or(QueryStringMissing) - .map_err(IntoResponse::into_response)?; + let query = req.uri().query().ok_or(QueryStringMissing)?; let value = serde_urlencoded::from_str(query) - .map_err(FailedToDeserializeQueryString::new::<T, _>) - .map_err(IntoResponse::into_response)?; + .map_err(FailedToDeserializeQueryString::new::<T, _>)?; Ok(Form(value)) } else { - let body = take_body(req).map_err(IntoResponse::into_response)?; + let body = take_body(req)?; let chunks = hyper::body::aggregate(body) .await - .map_err(FailedToBufferBody::from_err) - .map_err(IntoResponse::into_response)?; + .map_err(FailedToBufferBody::from_err)?; let value = serde_urlencoded::from_reader(chunks.reader()) - .map_err(FailedToDeserializeQueryString::new::<T, _>) - .map_err(IntoResponse::into_response)?; + .map_err(FailedToDeserializeQueryString::new::<T, _>)?; Ok(Form(value)) } @@ -327,26 +357,23 @@ impl<T> FromRequest for Json<T> where T: DeserializeOwned, { - type Rejection = Response<Body>; + type Rejection = JsonRejection; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { use bytes::Buf; if has_content_type(req, "application/json") { - let body = take_body(req).map_err(IntoResponse::into_response)?; + let body = take_body(req)?; let buf = hyper::body::aggregate(body) .await - .map_err(InvalidJsonBody::from_err) - .map_err(IntoResponse::into_response)?; + .map_err(InvalidJsonBody::from_err)?; - let value = serde_json::from_reader(buf.reader()) - .map_err(InvalidJsonBody::from_err) - .map_err(IntoResponse::into_response)?; + let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?; Ok(Json(value)) } else { - Err(MissingJsonContentType.into_response()) + Err(MissingJsonContentType.into()) } } } @@ -419,15 +446,14 @@ where #[async_trait] impl FromRequest for Bytes { - type Rejection = Response<Body>; + type Rejection = BytesRejection; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { - let body = take_body(req).map_err(IntoResponse::into_response)?; + let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) .await - .map_err(FailedToBufferBody::from_err) - .map_err(IntoResponse::into_response)?; + .map_err(FailedToBufferBody::from_err)?; Ok(bytes) } @@ -435,20 +461,17 @@ impl FromRequest for Bytes { #[async_trait] impl FromRequest for String { - type Rejection = Response<Body>; + type Rejection = StringRejection; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { - let body = take_body(req).map_err(IntoResponse::into_response)?; + let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) .await - .map_err(FailedToBufferBody::from_err) - .map_err(IntoResponse::into_response)? + .map_err(FailedToBufferBody::from_err)? .to_vec(); - let string = String::from_utf8(bytes) - .map_err(InvalidUtf8::from_err) - .map_err(IntoResponse::into_response)?; + let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?; Ok(string) } @@ -541,7 +564,7 @@ impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N> where T: FromRequest, { - type Rejection = Response<Body>; + type Rejection = ContentLengthLimitRejection<T::Rejection>; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); @@ -551,15 +574,17 @@ where if let Some(length) = content_length { if length > N { - return Err(PayloadTooLarge.into_response()); + return Err(ContentLengthLimitRejection::PayloadTooLarge( + PayloadTooLarge, + )); } } else { - return Err(LengthRequired.into_response()); + return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired)); }; let value = T::from_request(req) .await - .map_err(IntoResponse::into_response)?; + .map_err(ContentLengthLimitRejection::Inner)?; Ok(Self(value)) } @@ -603,7 +628,7 @@ impl UrlParamsMap { #[async_trait] impl FromRequest for UrlParamsMap { - type Rejection = Response<Body>; + type Rejection = UrlParamsMapRejection; async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { if let Some(params) = req @@ -613,10 +638,10 @@ impl FromRequest for UrlParamsMap { if let Some(params) = params.take() { Ok(Self(params.0.into_iter().collect())) } else { - Err(UrlParamsAlreadyExtracted.into_response()) + Err(UrlParamsAlreadyExtracted.into()) } } else { - Err(MissingRouteParams.into_response()) + Err(MissingRouteParams.into()) } } } @@ -656,7 +681,7 @@ macro_rules! impl_parse_url { $head: FromStr + Send, $( $tail: FromStr + Send, )* { - type Rejection = Response<Body>; + type Rejection = UrlParamsRejection; #[allow(non_snake_case)] async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { @@ -667,30 +692,30 @@ macro_rules! impl_parse_url { if let Some(params) = params.take() { params.0 } else { - return Err(UrlParamsAlreadyExtracted.into_response()); + return Err(UrlParamsAlreadyExtracted.into()); } } else { - return Err(MissingRouteParams.into_response()) + return Err(MissingRouteParams.into()) }; if let [(_, $head), $((_, $tail),)*] = &*params { let $head = if let Ok(x) = $head.parse::<$head>() { x } else { - return Err(InvalidUrlParam::new::<$head>().into_response()); + return Err(InvalidUrlParam::new::<$head>().into()); }; $( let $tail = if let Ok(x) = $tail.parse::<$tail>() { x } else { - return Err(InvalidUrlParam::new::<$tail>().into_response()); + return Err(InvalidUrlParam::new::<$tail>().into()); }; )* Ok(UrlParams(($head, $($tail,)*))) } else { - return Err(MissingRouteParams.into_response()) + Err(MissingRouteParams.into()) } } } diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 51c481cc..8f27efee 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -212,3 +212,151 @@ impl IntoResponse for FailedToDeserializeQueryString { res } } + +macro_rules! composite_rejection { + ( + $(#[$m:meta])* + pub enum $name:ident { + $($variant:ident),+ + $(,)? + } + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub enum $name { + $( + #[allow(missing_docs)] + $variant($variant) + ),+ + } + + impl IntoResponse for $name { + fn into_response(self) -> http::Response<Body> { + match self { + $( + Self::$variant(inner) => inner.into_response(), + )+ + } + } + } + + $( + impl From<$variant> for $name { + fn from(inner: $variant) -> Self { + Self::$variant(inner) + } + } + )+ + }; +} + +composite_rejection! { + /// Rejection used for [`Query`](super::Query). + /// + /// Contains one variant for each way the [`Query`](super::Query) extractor + /// can fail. + pub enum QueryRejection { + QueryStringMissing, + FailedToDeserializeQueryString, + } +} + +composite_rejection! { + /// Rejection used for [`Form`](super::Form). + /// + /// Contains one variant for each way the [`Form`](super::Form) extractor + /// can fail. + pub enum FormRejection { + InvalidFormContentType, + QueryStringMissing, + FailedToDeserializeQueryString, + FailedToBufferBody, + BodyAlreadyExtracted, + } +} + +composite_rejection! { + /// Rejection used for [`Json`](super::Json). + /// + /// Contains one variant for each way the [`Json`](super::Json) extractor + /// can fail. + pub enum JsonRejection { + InvalidJsonBody, + MissingJsonContentType, + BodyAlreadyExtracted, + } +} + +composite_rejection! { + /// Rejection used for [`UrlParamsMap`](super::UrlParamsMap). + /// + /// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor + /// can fail. + pub enum UrlParamsMapRejection { + UrlParamsAlreadyExtracted, + MissingRouteParams, + } +} + +composite_rejection! { + /// Rejection used for [`UrlParams`](super::UrlParams). + /// + /// Contains one variant for each way the [`UrlParams`](super::UrlParams) extractor + /// can fail. + pub enum UrlParamsRejection { + InvalidUrlParam, + UrlParamsAlreadyExtracted, + MissingRouteParams, + } +} + +composite_rejection! { + /// Rejection used for [`Bytes`](bytes::Bytes). + /// + /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor + /// can fail. + pub enum BytesRejection { + BodyAlreadyExtracted, + FailedToBufferBody, + } +} + +composite_rejection! { + /// Rejection used for [`String`]. + /// + /// Contains one variant for each way the [`String`] extractor can fail. + pub enum StringRejection { + BodyAlreadyExtracted, + FailedToBufferBody, + InvalidUtf8, + } +} + +/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit). +/// +/// Contains one variant for each way the +/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail. +#[derive(Debug)] +#[non_exhaustive] +pub enum ContentLengthLimitRejection<T> { + #[allow(missing_docs)] + PayloadTooLarge(PayloadTooLarge), + #[allow(missing_docs)] + LengthRequired(LengthRequired), + #[allow(missing_docs)] + Inner(T), +} + +impl<T> IntoResponse for ContentLengthLimitRejection<T> +where + T: IntoResponse, +{ + fn into_response(self) -> http::Response<Body> { + match self { + Self::PayloadTooLarge(inner) => inner.into_response(), + Self::LengthRequired(inner) => inner.into_response(), + Self::Inner(inner) => inner.into_response(), + } + } +}