mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Support getting error from extractors (#14)
Makes `Result<T, T::Rejection>` an extractor and makes all extraction errors enums so no type information is lost.
This commit is contained in:
parent
1002685a20
commit
2b360a7873
2 changed files with 234 additions and 61 deletions
|
@ -111,6 +111,41 @@
|
|||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! Wrapping extractors in `Result` makes them optional and gives you the reason
|
||||
//! the extraction failed:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use awebframework::{extract::{Json, rejection::JsonRejection}, prelude::*};
|
||||
//! use serde_json::Value;
|
||||
//!
|
||||
//! async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
|
||||
//! match payload {
|
||||
//! Ok(payload) => {
|
||||
//! // We got a valid JSON payload
|
||||
//! }
|
||||
//! Err(JsonRejection::MissingJsonContentType(_)) => {
|
||||
//! // Request didn't have `Content-Type: application/json`
|
||||
//! // header
|
||||
//! }
|
||||
//! Err(JsonRejection::InvalidJsonBody(_)) => {
|
||||
//! // Couldn't deserialize the body into the target type
|
||||
//! }
|
||||
//! Err(JsonRejection::BodyAlreadyExtracted(_)) => {
|
||||
//! // Another extractor had already consumed the body
|
||||
//! }
|
||||
//! Err(_) => {
|
||||
//! // `JsonRejection` is marked `#[non_exhaustive]` so match must
|
||||
//! // include a catch-all case.
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! let app = route("/users", post(create_user));
|
||||
//! # async {
|
||||
//! # app.serve(&"".parse().unwrap()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! # Reducing boilerplate
|
||||
//!
|
||||
//! If you're feeling adventorous you can even deconstruct the extractors
|
||||
|
@ -133,13 +168,8 @@
|
|||
use crate::{body::Body, response::IntoResponse};
|
||||
use async_trait::async_trait;
|
||||
use bytes::{Buf, Bytes};
|
||||
use http::{header, HeaderMap, Method, Request, Response, Uri, Version};
|
||||
use rejection::{
|
||||
BodyAlreadyExtracted, FailedToBufferBody, FailedToDeserializeQueryString,
|
||||
InvalidFormContentType, InvalidJsonBody, InvalidUrlParam, InvalidUtf8, LengthRequired,
|
||||
MissingExtension, MissingJsonContentType, MissingRouteParams, PayloadTooLarge,
|
||||
QueryStringMissing, RequestAlreadyExtracted, UrlParamsAlreadyExtracted,
|
||||
};
|
||||
use http::{header, HeaderMap, Method, Request, Uri, Version};
|
||||
use rejection::*;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{collections::HashMap, convert::Infallible, mem, str::FromStr};
|
||||
|
||||
|
@ -170,6 +200,18 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> FromRequest for Result<T, T::Rejection>
|
||||
where
|
||||
T: FromRequest,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
Ok(T::from_request(req).await)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extractor that deserializes query strings into some type.
|
||||
///
|
||||
/// `T` is expected to implement [`serde::Deserialize`].
|
||||
|
@ -207,17 +249,12 @@ impl<T> FromRequest for Query<T>
|
|||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = QueryRejection;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let query = req
|
||||
.uri()
|
||||
.query()
|
||||
.ok_or(QueryStringMissing)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
let query = req.uri().query().ok_or(QueryStringMissing)?;
|
||||
let value = serde_urlencoded::from_str(query)
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
|
||||
Ok(Query(value))
|
||||
}
|
||||
}
|
||||
|
@ -257,33 +294,26 @@ impl<T> FromRequest for Form<T>
|
|||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = FormRejection;
|
||||
|
||||
#[allow(warnings)]
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
if !has_content_type(&req, "application/x-www-form-urlencoded") {
|
||||
return Err(InvalidFormContentType.into_response());
|
||||
Err(InvalidFormContentType)?;
|
||||
}
|
||||
|
||||
if req.method() == Method::GET {
|
||||
let query = req
|
||||
.uri()
|
||||
.query()
|
||||
.ok_or(QueryStringMissing)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
let query = req.uri().query().ok_or(QueryStringMissing)?;
|
||||
let value = serde_urlencoded::from_str(query)
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
|
||||
Ok(Form(value))
|
||||
} else {
|
||||
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||
let body = take_body(req)?;
|
||||
let chunks = hyper::body::aggregate(body)
|
||||
.await
|
||||
.map_err(FailedToBufferBody::from_err)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(FailedToBufferBody::from_err)?;
|
||||
let value = serde_urlencoded::from_reader(chunks.reader())
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
|
||||
|
||||
Ok(Form(value))
|
||||
}
|
||||
|
@ -327,26 +357,23 @@ impl<T> FromRequest for Json<T>
|
|||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = JsonRejection;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
use bytes::Buf;
|
||||
|
||||
if has_content_type(req, "application/json") {
|
||||
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||
let body = take_body(req)?;
|
||||
|
||||
let buf = hyper::body::aggregate(body)
|
||||
.await
|
||||
.map_err(InvalidJsonBody::from_err)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(InvalidJsonBody::from_err)?;
|
||||
|
||||
let value = serde_json::from_reader(buf.reader())
|
||||
.map_err(InvalidJsonBody::from_err)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
let value = serde_json::from_reader(buf.reader()).map_err(InvalidJsonBody::from_err)?;
|
||||
|
||||
Ok(Json(value))
|
||||
} else {
|
||||
Err(MissingJsonContentType.into_response())
|
||||
Err(MissingJsonContentType.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -419,15 +446,14 @@ where
|
|||
|
||||
#[async_trait]
|
||||
impl FromRequest for Bytes {
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = BytesRejection;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||
let body = take_body(req)?;
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(FailedToBufferBody::from_err)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(FailedToBufferBody::from_err)?;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
@ -435,20 +461,17 @@ impl FromRequest for Bytes {
|
|||
|
||||
#[async_trait]
|
||||
impl FromRequest for String {
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = StringRejection;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let body = take_body(req).map_err(IntoResponse::into_response)?;
|
||||
let body = take_body(req)?;
|
||||
|
||||
let bytes = hyper::body::to_bytes(body)
|
||||
.await
|
||||
.map_err(FailedToBufferBody::from_err)
|
||||
.map_err(IntoResponse::into_response)?
|
||||
.map_err(FailedToBufferBody::from_err)?
|
||||
.to_vec();
|
||||
|
||||
let string = String::from_utf8(bytes)
|
||||
.map_err(InvalidUtf8::from_err)
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
let string = String::from_utf8(bytes).map_err(InvalidUtf8::from_err)?;
|
||||
|
||||
Ok(string)
|
||||
}
|
||||
|
@ -541,7 +564,7 @@ impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
|
|||
where
|
||||
T: FromRequest,
|
||||
{
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = ContentLengthLimitRejection<T::Rejection>;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
||||
|
@ -551,15 +574,17 @@ where
|
|||
|
||||
if let Some(length) = content_length {
|
||||
if length > N {
|
||||
return Err(PayloadTooLarge.into_response());
|
||||
return Err(ContentLengthLimitRejection::PayloadTooLarge(
|
||||
PayloadTooLarge,
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(LengthRequired.into_response());
|
||||
return Err(ContentLengthLimitRejection::LengthRequired(LengthRequired));
|
||||
};
|
||||
|
||||
let value = T::from_request(req)
|
||||
.await
|
||||
.map_err(IntoResponse::into_response)?;
|
||||
.map_err(ContentLengthLimitRejection::Inner)?;
|
||||
|
||||
Ok(Self(value))
|
||||
}
|
||||
|
@ -603,7 +628,7 @@ impl UrlParamsMap {
|
|||
|
||||
#[async_trait]
|
||||
impl FromRequest for UrlParamsMap {
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = UrlParamsMapRejection;
|
||||
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
if let Some(params) = req
|
||||
|
@ -613,10 +638,10 @@ impl FromRequest for UrlParamsMap {
|
|||
if let Some(params) = params.take() {
|
||||
Ok(Self(params.0.into_iter().collect()))
|
||||
} else {
|
||||
Err(UrlParamsAlreadyExtracted.into_response())
|
||||
Err(UrlParamsAlreadyExtracted.into())
|
||||
}
|
||||
} else {
|
||||
Err(MissingRouteParams.into_response())
|
||||
Err(MissingRouteParams.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -656,7 +681,7 @@ macro_rules! impl_parse_url {
|
|||
$head: FromStr + Send,
|
||||
$( $tail: FromStr + Send, )*
|
||||
{
|
||||
type Rejection = Response<Body>;
|
||||
type Rejection = UrlParamsRejection;
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||
|
@ -667,30 +692,30 @@ macro_rules! impl_parse_url {
|
|||
if let Some(params) = params.take() {
|
||||
params.0
|
||||
} else {
|
||||
return Err(UrlParamsAlreadyExtracted.into_response());
|
||||
return Err(UrlParamsAlreadyExtracted.into());
|
||||
}
|
||||
} else {
|
||||
return Err(MissingRouteParams.into_response())
|
||||
return Err(MissingRouteParams.into())
|
||||
};
|
||||
|
||||
if let [(_, $head), $((_, $tail),)*] = &*params {
|
||||
let $head = if let Ok(x) = $head.parse::<$head>() {
|
||||
x
|
||||
} else {
|
||||
return Err(InvalidUrlParam::new::<$head>().into_response());
|
||||
return Err(InvalidUrlParam::new::<$head>().into());
|
||||
};
|
||||
|
||||
$(
|
||||
let $tail = if let Ok(x) = $tail.parse::<$tail>() {
|
||||
x
|
||||
} else {
|
||||
return Err(InvalidUrlParam::new::<$tail>().into_response());
|
||||
return Err(InvalidUrlParam::new::<$tail>().into());
|
||||
};
|
||||
)*
|
||||
|
||||
Ok(UrlParams(($head, $($tail,)*)))
|
||||
} else {
|
||||
return Err(MissingRouteParams.into_response())
|
||||
Err(MissingRouteParams.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -212,3 +212,151 @@ impl IntoResponse for FailedToDeserializeQueryString {
|
|||
res
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! composite_rejection {
|
||||
(
|
||||
$(#[$m:meta])*
|
||||
pub enum $name:ident {
|
||||
$($variant:ident),+
|
||||
$(,)?
|
||||
}
|
||||
) => {
|
||||
$(#[$m])*
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum $name {
|
||||
$(
|
||||
#[allow(missing_docs)]
|
||||
$variant($variant)
|
||||
),+
|
||||
}
|
||||
|
||||
impl IntoResponse for $name {
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
match self {
|
||||
$(
|
||||
Self::$variant(inner) => inner.into_response(),
|
||||
)+
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$(
|
||||
impl From<$variant> for $name {
|
||||
fn from(inner: $variant) -> Self {
|
||||
Self::$variant(inner)
|
||||
}
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Query`](super::Query).
|
||||
///
|
||||
/// Contains one variant for each way the [`Query`](super::Query) extractor
|
||||
/// can fail.
|
||||
pub enum QueryRejection {
|
||||
QueryStringMissing,
|
||||
FailedToDeserializeQueryString,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Form`](super::Form).
|
||||
///
|
||||
/// Contains one variant for each way the [`Form`](super::Form) extractor
|
||||
/// can fail.
|
||||
pub enum FormRejection {
|
||||
InvalidFormContentType,
|
||||
QueryStringMissing,
|
||||
FailedToDeserializeQueryString,
|
||||
FailedToBufferBody,
|
||||
BodyAlreadyExtracted,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Json`](super::Json).
|
||||
///
|
||||
/// Contains one variant for each way the [`Json`](super::Json) extractor
|
||||
/// can fail.
|
||||
pub enum JsonRejection {
|
||||
InvalidJsonBody,
|
||||
MissingJsonContentType,
|
||||
BodyAlreadyExtracted,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`UrlParamsMap`](super::UrlParamsMap).
|
||||
///
|
||||
/// Contains one variant for each way the [`UrlParamsMap`](super::UrlParamsMap) extractor
|
||||
/// can fail.
|
||||
pub enum UrlParamsMapRejection {
|
||||
UrlParamsAlreadyExtracted,
|
||||
MissingRouteParams,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`UrlParams`](super::UrlParams).
|
||||
///
|
||||
/// Contains one variant for each way the [`UrlParams`](super::UrlParams) extractor
|
||||
/// can fail.
|
||||
pub enum UrlParamsRejection {
|
||||
InvalidUrlParam,
|
||||
UrlParamsAlreadyExtracted,
|
||||
MissingRouteParams,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Bytes`](bytes::Bytes).
|
||||
///
|
||||
/// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor
|
||||
/// can fail.
|
||||
pub enum BytesRejection {
|
||||
BodyAlreadyExtracted,
|
||||
FailedToBufferBody,
|
||||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`String`].
|
||||
///
|
||||
/// Contains one variant for each way the [`String`] extractor can fail.
|
||||
pub enum StringRejection {
|
||||
BodyAlreadyExtracted,
|
||||
FailedToBufferBody,
|
||||
InvalidUtf8,
|
||||
}
|
||||
}
|
||||
|
||||
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
|
||||
///
|
||||
/// Contains one variant for each way the
|
||||
/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail.
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum ContentLengthLimitRejection<T> {
|
||||
#[allow(missing_docs)]
|
||||
PayloadTooLarge(PayloadTooLarge),
|
||||
#[allow(missing_docs)]
|
||||
LengthRequired(LengthRequired),
|
||||
#[allow(missing_docs)]
|
||||
Inner(T),
|
||||
}
|
||||
|
||||
impl<T> IntoResponse for ContentLengthLimitRejection<T>
|
||||
where
|
||||
T: IntoResponse,
|
||||
{
|
||||
fn into_response(self) -> http::Response<Body> {
|
||||
match self {
|
||||
Self::PayloadTooLarge(inner) => inner.into_response(),
|
||||
Self::LengthRequired(inner) => inner.into_response(),
|
||||
Self::Inner(inner) => inner.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue