diff --git a/src/extract.rs b/src/extract.rs index cf92d9fc..e516ed4e 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,46 +1,22 @@ use crate::{body::Body, Error}; +use async_trait::async_trait; use bytes::Bytes; -use futures_util::{future, ready}; use http::{header, Request, StatusCode}; -use pin_project::pin_project; use serde::de::DeserializeOwned; -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - str::FromStr, - task::{Context, Poll}, -}; +use std::{collections::HashMap, str::FromStr}; +#[async_trait] pub trait FromRequest: Sized { - type Future: Future<Output = Result<Self, Error>> + Send; - - fn from_request(req: &mut Request<Body>) -> Self::Future; + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error>; } +#[async_trait] impl<T> FromRequest for Option<T> where T: FromRequest, { - type Future = OptionFromRequestFuture<T::Future>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { - OptionFromRequestFuture(T::from_request(req)) - } -} - -#[pin_project] -pub struct OptionFromRequestFuture<F>(#[pin] F); - -impl<F, T> Future for OptionFromRequestFuture<F> -where - F: Future<Output = Result<T, Error>>, -{ - type Output = Result<Option<T>, Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let value = ready!(self.project().0.poll(cx)); - Poll::Ready(Ok(value.ok())) + async fn from_request(req: &mut Request<Body>) -> Result<Option<T>, Error> { + Ok(T::from_request(req).await.ok()) } } @@ -53,20 +29,15 @@ impl<T> Query<T> { } } +#[async_trait] impl<T> FromRequest for Query<T> where - T: DeserializeOwned + Send, + T: DeserializeOwned, { - type Future = future::Ready<Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { - let result = (|| { - let query = req.uri().query().ok_or(Error::QueryStringMissing)?; - let value = serde_urlencoded::from_str(query).map_err(Error::DeserializeQueryString)?; - Ok(Query(value)) - })(); - - future::ready(result) + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { + let query = req.uri().query().ok_or(Error::QueryStringMissing)?; + let value = serde_urlencoded::from_str(query).map_err(Error::DeserializeQueryString)?; + Ok(Query(value)) } } @@ -79,26 +50,22 @@ impl<T> Json<T> { } } +#[async_trait] impl<T> FromRequest for Json<T> where T: DeserializeOwned, { - type Future = future::BoxFuture<'static, Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { if has_content_type(&req, "application/json") { let body = std::mem::take(req.body_mut()); - Box::pin(async move { - let bytes = hyper::body::to_bytes(body) - .await - .map_err(Error::ConsumeRequestBody)?; - let value = - serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; - Ok(Json(value)) - }) + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)?; + let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; + Ok(Json(value)) } else { - Box::pin(async { Err(Error::Status(StatusCode::BAD_REQUEST)) }) + Err(Error::Status(StatusCode::BAD_REQUEST)) } } } @@ -128,66 +95,58 @@ impl<T> Extension<T> { } } +#[async_trait] impl<T> FromRequest for Extension<T> where T: Clone + Send + Sync + 'static, { - type Future = future::Ready<Result<Self, Error>>; + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { + let value = req + .extensions() + .get::<T>() + .ok_or_else(|| Error::MissingExtension { + type_name: std::any::type_name::<T>(), + }) + .map(|x| x.clone())?; - fn from_request(req: &mut Request<Body>) -> Self::Future { - let result = (|| { - let value = req - .extensions() - .get::<T>() - .ok_or_else(|| Error::MissingExtension { - type_name: std::any::type_name::<T>(), - }) - .map(|x| x.clone())?; - Ok(Extension(value)) - })(); - - future::ready(result) + Ok(Extension(value)) } } +#[async_trait] impl FromRequest for Bytes { - type Future = future::BoxFuture<'static, Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { let body = std::mem::take(req.body_mut()); - Box::pin(async move { - let bytes = hyper::body::to_bytes(body) - .await - .map_err(Error::ConsumeRequestBody)?; - Ok(bytes) - }) + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)?; + + Ok(bytes) } } +#[async_trait] impl FromRequest for String { - type Future = future::BoxFuture<'static, Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { let body = std::mem::take(req.body_mut()); - Box::pin(async move { - let bytes = hyper::body::to_bytes(body) - .await - .map_err(Error::ConsumeRequestBody)? - .to_vec(); - let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?; - Ok(string) - }) + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)? + .to_vec(); + + let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?; + + Ok(string) } } +#[async_trait] impl FromRequest for Body { - type Future = future::Ready<Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { let body = std::mem::take(req.body_mut()); - future::ok(body) + Ok(body) } } @@ -200,31 +159,28 @@ impl<const N: u64> BytesMaxLength<N> { } } +#[async_trait] impl<const N: u64> FromRequest for BytesMaxLength<N> { - type Future = future::BoxFuture<'static, Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); let body = std::mem::take(req.body_mut()); - Box::pin(async move { - let content_length = - content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); + let content_length = + content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); - if let Some(length) = content_length { - if length > N { - return Err(Error::PayloadTooLarge); - } - } else { - return Err(Error::LengthRequired); - }; + if let Some(length) = content_length { + if length > N { + return Err(Error::PayloadTooLarge); + } + } else { + return Err(Error::LengthRequired); + }; - let bytes = hyper::body::to_bytes(body) - .await - .map_err(Error::ConsumeRequestBody)?; + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)?; - Ok(BytesMaxLength(bytes)) - }) + Ok(BytesMaxLength(bytes)) } } @@ -249,16 +205,15 @@ impl UrlParamsMap { } } +#[async_trait] impl FromRequest for UrlParamsMap { - type Future = future::Ready<Result<Self, Error>>; - - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { if let Some(params) = req .extensions_mut() .get_mut::<Option<crate::routing::UrlParams>>() { let params = params.take().expect("params already taken").0; - future::ok(Self(params.into_iter().collect())) + Ok(Self(params.into_iter().collect())) } else { panic!("no url params found for matched route. This is a bug in tower-web") } @@ -277,15 +232,14 @@ macro_rules! impl_parse_url { () => {}; ( $head:ident, $($tail:ident),* $(,)? ) => { + #[async_trait] impl<$head, $($tail,)*> FromRequest for UrlParams<($head, $($tail,)*)> where $head: FromStr + Send, $( $tail: FromStr + Send, )* { - type Future = future::Ready<Result<Self, Error>>; - #[allow(non_snake_case)] - fn from_request(req: &mut Request<Body>) -> Self::Future { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { let params = if let Some(params) = req .extensions_mut() .get_mut::<Option<crate::routing::UrlParams>>() @@ -299,7 +253,7 @@ macro_rules! impl_parse_url { let $head = if let Ok(x) = $head.parse::<$head>() { x } else { - return future::err(Error::InvalidUrlParam { + return Err(Error::InvalidUrlParam { type_name: std::any::type_name::<$head>(), }); }; @@ -308,13 +262,13 @@ macro_rules! impl_parse_url { let $tail = if let Ok(x) = $tail.parse::<$tail>() { x } else { - return future::err(Error::InvalidUrlParam { + return Err(Error::InvalidUrlParam { type_name: std::any::type_name::<$tail>(), }); }; )* - future::ok(UrlParams(($head, $($tail,)*))) + Ok(UrlParams(($head, $($tail,)*))) } else { panic!("wrong number of url params found for matched route. This is a bug in tower-web") }