Support getting error from extractors (#14)

Makes `Result<T, T::Rejection>` an extractor and makes all extraction errors enums so no type information is lost.
This commit is contained in:
David Pedersen 2021-06-13 12:06:59 +02:00 committed by GitHub
parent 1002685a20
commit 2b360a7873
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 234 additions and 61 deletions

View file

@ -111,6 +111,41 @@
//! # };
//! ```
//!
//! Wrapping extractors in `Result` makes them optional and gives you the reason
//! the extraction failed:
//!
//! ```rust,no_run
//! use awebframework::{extract::{Json, rejection::JsonRejection}, prelude::*};
//! use serde_json::Value;
//!
//! async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
//! match payload {
//! Ok(payload) => {
//! // We got a valid JSON payload
//! }
//! Err(JsonRejection::MissingJsonContentType(_)) => {
//! // Request didn't have `Content-Type: application/json`
//! // header
//! }
//! Err(JsonRejection::InvalidJsonBody(_)) => {
//! // Couldn't deserialize the body into the target type
//! }
//! Err(JsonRejection::BodyAlreadyExtracted(_)) => {
//! // Another extractor had already consumed the body
//! }
//! Err(_) => {
//! // `JsonRejection` is marked `#[non_exhaustive]` so match must
//! // include a catch-all case.
//! }
//! }
//! }
//!
//! let app = route("/users", post(create_user));
//! # async {
//! # app.serve(&"".parse().unwrap()).await.unwrap();
//! # };
//! ```
//!
//! # Reducing boilerplate
//!
//! If you're feeling adventorous you can even deconstruct the extractors
@ -133,13 +168,8 @@
use crate::{body::Body, response::IntoResponse};
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use http::{header, HeaderMap, Method, Request, Response, Uri, Version};
use rejection::{
BodyAlreadyExtracted, FailedToBufferBody, FailedToDeserializeQueryString,
InvalidFormContentType, InvalidJsonBody, InvalidUrlParam, InvalidUtf8, LengthRequired,
MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
QueryStringMissing, RequestAlreadyExtracted, UrlParamsAlreadyExtracted,
};
use http::{header, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
use serde::de::DeserializeOwned;
use std::{collections::HashMap, convert::Infallible, mem, str::FromStr};
@ -170,6 +200,18 @@ where
}
}
#[async_trait]
impl<T> FromRequest for Result<T, T::Rejection>
where
T: FromRequest,
{
type Rejection = Infallible;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await)
}
}
/// Extractor that deserializes query strings into some type.
///
/// `T` is expected to implement [`serde::Deserialize`].
@ -207,17 +249,12 @@ impl<T> FromRequest for Query<T>
where
T: DeserializeOwned,
{
type Rejection = Response<Body>;
type Rejection = QueryRejection;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let query = req
.uri()
.query()
.ok_or(QueryStringMissing)
.map_err(IntoResponse::into_response)?;
let query = req.uri().query().ok_or(QueryStringMissing)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Query(value))
}
}
@ -257,33 +294,26 @@ impl<T> FromRequest for Form<T>
where
T: DeserializeOwned,
{
type Rejection = Response<Body>;
type Rejection = FormRejection;
#[allow(warnings)]
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
if !has_content_type(&req, "application/x-www-form-urlencoded") {
return Err(InvalidFormContentType.into_response());
Err(InvalidFormContentType)?;
}
if req.method() == Method::GET {
let query = req
.uri()
.query()
.ok_or(QueryStringMissing)
.map_err(IntoResponse::into_response)?;
let query = req.uri().query().ok_or(QueryStringMissing)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
} else {
let body = take_body(req).map_err(IntoResponse::into_response)?;
let body = take_body(req)?;
let chunks = hyper::body::aggregate(body)
.await
.map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::into_response)?;
.map_err(FailedToBufferBody::from_err)?;
let value = serde_urlencoded::from_reader(chunks.reader())
.map_err(FailedToDeserializeQueryString::new::<T, _>)
.map_err(IntoResponse::into_response)?;
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
}
@ -327,26 +357,23 @@ impl<T> FromRequest for Json<T>
where
T: DeserializeOwned,
{
type Rejection = Response<Body>;
type Rejection = JsonRejection;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
use bytes::Buf;
if has_content_type(req, "application/json") {
let body = take_body(req).map_err(IntoResponse::into_response)?;
let body = take_body(req)?;
let buf = hyper::body::aggregate(body)
.await
.map_err(InvalidJsonBody::from_err)
.map_err(IntoResponse::into_response)?;
.map_err(InvalidJsonBody::from_err)?;
let value = serde_json::from_reader(buf.reader())
.map_err(InvalidJsonBody::from_err)
.map_err(IntoResponse::into_response)?;
let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?;
Ok(Json(value))
} else {
Err(MissingJsonContentType.into_response())
Err(MissingJsonContentType.into())
}
}
}
@ -419,15 +446,14 @@ where
#[async_trait]
impl FromRequest for Bytes {
type Rejection = Response<Body>;
type Rejection = BytesRejection;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let body = take_body(req).map_err(IntoResponse::into_response)?;
let body = take_body(req)?;
let bytes = hyper::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::into_response)?;
.map_err(FailedToBufferBody::from_err)?;
Ok(bytes)
}
@ -435,20 +461,17 @@ impl FromRequest for Bytes {
#[async_trait]
impl FromRequest for String {
type Rejection = Response<Body>;
type Rejection = StringRejection;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let body = take_body(req).map_err(IntoResponse::into_response)?;
let body = take_body(req)?;
let bytes = hyper::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)
.map_err(IntoResponse::into_response)?
.map_err(FailedToBufferBody::from_err)?
.to_vec();
let string = String::from_utf8(bytes)
.map_err(InvalidUtf8::from_err)
.map_err(IntoResponse::into_response)?;
let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?;
Ok(string)
}
@ -541,7 +564,7 @@ impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
where
T: FromRequest,
{
type Rejection = Response<Body>;
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
@ -551,15 +574,17 @@ where
if let Some(length) = content_length {
if length > N {
return Err(PayloadTooLarge.into_response());
return Err(ContentLengthLimitRejection::PayloadTooLarge(
PayloadTooLarge,
));
}
} else {
return Err(LengthRequired.into_response());
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
};
let value = T::from_request(req)
.await
.map_err(IntoResponse::into_response)?;
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
@ -603,7 +628,7 @@ impl UrlParamsMap {
#[async_trait]
impl FromRequest for UrlParamsMap {
type Rejection = Response<Body>;
type Rejection = UrlParamsMapRejection;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
if let Some(params) = req
@ -613,10 +638,10 @@ impl FromRequest for UrlParamsMap {
if let Some(params) = params.take() {
Ok(Self(params.0.into_iter().collect()))
} else {
Err(UrlParamsAlreadyExtracted.into_response())
Err(UrlParamsAlreadyExtracted.into())
}
} else {
Err(MissingRouteParams.into_response())
Err(MissingRouteParams.into())
}
}
}
@ -656,7 +681,7 @@ macro_rules! impl_parse_url {
$head: FromStr + Send,
$( $tail: FromStr + Send, )*
{
type Rejection = Response<Body>;
type Rejection = UrlParamsRejection;
#[allow(non_snake_case)]
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
@ -667,30 +692,30 @@ macro_rules! impl_parse_url {
if let Some(params) = params.take() {
params.0
} else {
return Err(UrlParamsAlreadyExtracted.into_response());
return Err(UrlParamsAlreadyExtracted.into());
}
} else {
return Err(MissingRouteParams.into_response())
return Err(MissingRouteParams.into())
};
if let [(_, $head), $((_, $tail),)*] = &*params {
let $head = if let Ok(x) = $head.parse::<$head>() {
x
} else {
return Err(InvalidUrlParam::new::<$head>().into_response());
return Err(InvalidUrlParam::new::<$head>().into());
};
$(
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
x
} else {
return Err(InvalidUrlParam::new::<$tail>().into_response());
return Err(InvalidUrlParam::new::<$tail>().into());
};
)*
Ok(UrlParams(($head, $($tail,)*)))
} else {
return Err(MissingRouteParams.into_response())
Err(MissingRouteParams.into())
}
}
}

View file

@ -212,3 +212,151 @@ impl IntoResponse for FailedToDeserializeQueryString {
res
}
}
macro_rules! composite_rejection {
(
$(#[$m:meta])*
pub enum $name:ident {
$($variant:ident),+
$(,)?
}
) => {
$(#[$m])*
#[derive(Debug)]
#[non_exhaustive]
pub enum $name {
$(
#[allow(missing_docs)]
$variant($variant)
),+
}
impl IntoResponse for $name {
fn into_response(self) -> http::Response<Body> {
match self {
$(
Self::$variant(inner) => inner.into_response(),
)+
}
}
}
$(
impl From<$variant> for $name {
fn from(inner: $variant) -> Self {
Self::$variant(inner)
}
}
)+
};
}
composite_rejection! {
/// Rejection used for [`Query`](super::Query).
///
/// Contains one variant for each way the [`Query`](super::Query) extractor
/// can fail.
pub enum QueryRejection {
QueryStringMissing,
FailedToDeserializeQueryString,
}
}
composite_rejection! {
/// Rejection used for [`Form`](super::Form).
///
/// Contains one variant for each way the [`Form`](super::Form) extractor
/// can fail.
pub enum FormRejection {
InvalidFormContentType,
QueryStringMissing,
FailedToDeserializeQueryString,
FailedToBufferBody,
BodyAlreadyExtracted,
}
}
composite_rejection! {
/// Rejection used for [`Json`](super::Json).
///
/// Contains one variant for each way the [`Json`](super::Json) extractor
/// can fail.
pub enum JsonRejection {
InvalidJsonBody,
MissingJsonContentType,
BodyAlreadyExtracted,
}
}
composite_rejection! {
/// Rejection used for [`UrlParamsMap`](super::UrlParamsMap).
///
/// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor
/// can fail.
pub enum UrlParamsMapRejection {
UrlParamsAlreadyExtracted,
MissingRouteParams,
}
}
composite_rejection! {
/// Rejection used for [`UrlParams`](super::UrlParams).
///
/// Contains one variant for each way the [`UrlParams`](super::UrlParams) extractor
/// can fail.
pub enum UrlParamsRejection {
InvalidUrlParam,
UrlParamsAlreadyExtracted,
MissingRouteParams,
}
}
composite_rejection! {
/// Rejection used for [`Bytes`](bytes::Bytes).
///
/// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor
/// can fail.
pub enum BytesRejection {
BodyAlreadyExtracted,
FailedToBufferBody,
}
}
composite_rejection! {
/// Rejection used for [`String`].
///
/// Contains one variant for each way the [`String`] extractor can fail.
pub enum StringRejection {
BodyAlreadyExtracted,
FailedToBufferBody,
InvalidUtf8,
}
}
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
///
/// Contains one variant for each way the
/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
PayloadTooLarge(PayloadTooLarge),
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
Inner(T),
}
impl<T> IntoResponse for ContentLengthLimitRejection<T>
where
T: IntoResponse,
{
fn into_response(self) -> http::Response<Body> {
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
}