Make extractors easier to write (#36)

Previously extractors worked directly on `Request<B>` which meant you
had to do weird tricks like `mem::take(req.headers_mut())` to get owned
parts of the request.

This changes that instead to use a new `RequestParts` type that have
methods to "take" each part of the request. Without having to do weird
tricks.

Also removed the need to have `B: Default` for body extractors.
This commit is contained in:
David Pedersen 2021-07-22 13:23:50 +02:00 committed by GitHub
parent e544fe1c39
commit f32d325e55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 441 additions and 193 deletions

View file

@ -7,7 +7,8 @@
//! ```
use axum::{
extract::{ContentLengthLimit, Extension, UrlParams},
async_trait,
extract::{extractor_middleware, ContentLengthLimit, Extension, RequestParts, UrlParams},
prelude::*,
response::IntoResponse,
routing::BoxRoute,
@ -24,8 +25,7 @@ use std::{
};
use tower::{BoxError, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer,
compression::CompressionLayer, trace::TraceLayer,
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
};
#[tokio::main]
@ -118,10 +118,40 @@ fn admin_routes() -> BoxRoute<hyper::Body> {
route("/keys", delete(delete_all_keys))
.route("/key/:key", delete(remove_key))
// Require beare auth for all admin routes
.layer(RequireAuthorizationLayer::bearer("secret-token"))
.layer(extractor_middleware::<RequireAuth>())
.boxed()
}
/// An extractor that performs authorization.
// TODO: when https://github.com/hyperium/http-body/pull/46 is merged we can use
// `tower_http::auth::RequireAuthorization` instead
struct RequireAuth;
#[async_trait]
impl<B> extract::FromRequest<B> for RequireAuth
where
B: Send,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let auth_header = req
.headers()
.and_then(|headers| headers.get(http::header::AUTHORIZATION))
.and_then(|value| value.to_str().ok());
if let Some(value) = auth_header {
if let Some(token) = value.strip_prefix("Bearer ") {
if token == "secret-token" {
return Ok(Self);
}
}
}
Err(StatusCode::UNAUTHORIZED)
}
}
fn handle_error(error: BoxError) -> impl IntoResponse {
if error.is::<tower::timeout::error::Elapsed>() {
return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"));

View file

@ -1,5 +1,9 @@
use axum::response::IntoResponse;
use axum::{async_trait, extract::FromRequest, prelude::*};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
prelude::*,
};
use http::Response;
use http::StatusCode;
use std::net::SocketAddr;
@ -36,7 +40,7 @@ where
{
type Rejection = Response<Body>;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let params = extract::UrlParamsMap::from_request(req)
.await
.map_err(IntoResponse::into_response)?;

View file

@ -1,13 +1,8 @@
//! HTTP body utilities.
use bytes::Bytes;
use http_body::{Empty, Full};
use std::{
error::Error as StdError,
fmt,
pin::Pin,
task::{Context, Poll},
};
use http_body::Body as _;
use std::{error::Error as StdError, fmt};
use tower::BoxError;
pub use hyper::body::Body;
@ -16,75 +11,18 @@ pub use hyper::body::Body;
///
/// This is used in axum as the response body type for applications. Its necessary to unify
/// multiple response bodies types into one.
pub struct BoxBody {
// when we've gotten rid of `BoxStdError` we should be able to change the error type to
// `BoxError`
inner: Pin<Box<dyn http_body::Body<Data = Bytes, Error = BoxStdError> + Send + Sync + 'static>>,
}
pub type BoxBody = http_body::combinators::BoxBody<Bytes, BoxStdError>;
impl BoxBody {
/// Create a new `BoxBody`.
pub fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
Self {
inner: Box::pin(body.map_err(|error| BoxStdError(error.into()))),
}
}
pub(crate) fn empty() -> Self {
Self::new(Empty::new())
}
}
impl Default for BoxBody {
fn default() -> Self {
BoxBody::empty()
}
}
impl fmt::Debug for BoxBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxBody").finish()
}
}
impl http_body::Body for BoxBody {
type Data = Bytes;
type Error = BoxStdError;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.inner.as_mut().poll_data(cx)
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
self.inner.as_mut().poll_trailers(cx)
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}
}
impl<B> From<B> for BoxBody
pub(crate) fn box_body<B>(body: B) -> BoxBody
where
B: Into<Bytes>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
fn from(s: B) -> Self {
BoxBody::new(Full::from(s.into()))
}
body.map_err(|err| BoxStdError(err.into())).boxed()
}
pub(crate) fn empty() -> BoxBody {
box_body(http_body::Empty::new())
}
/// A boxed error trait object that implements [`std::error::Error`].

View file

@ -2,7 +2,7 @@
//!
//! See [`extractor_middleware`] for more details.
use super::FromRequest;
use super::{FromRequest, RequestParts};
use crate::{body::BoxBody, response::IntoResponse};
use bytes::Bytes;
use futures_util::{future::BoxFuture, ready};
@ -34,7 +34,7 @@ use tower::{BoxError, Layer, Service};
/// # Example
///
/// ```rust
/// use axum::{extract::extractor_middleware, prelude::*};
/// use axum::{extract::{extractor_middleware, RequestParts}, prelude::*};
/// use http::StatusCode;
/// use async_trait::async_trait;
///
@ -48,12 +48,13 @@ use tower::{BoxError, Layer, Service};
/// {
/// type Rejection = StatusCode;
///
/// async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
/// if let Some(value) = req
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .get(http::header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok())
/// {
/// .and_then(|headers| headers.get(http::header::AUTHORIZATION))
/// .and_then(|value| value.to_str().ok());
///
/// if let Some(value) = auth_header {
/// if value == "secret" {
/// return Ok(Self);
/// }
@ -169,8 +170,9 @@ where
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let extract_future = Box::pin(async move {
let mut req = super::RequestParts::new(req);
let extracted = E::from_request(&mut req).await;
(req, extracted)
});
@ -201,7 +203,7 @@ where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
Extracting(BoxFuture<'static, (Request<ReqBody>, Result<E, E::Rejection>)>),
Extracting(BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)>),
Call(#[pin] S::Future),
}
@ -220,16 +222,16 @@ where
let new_state = match this.state.as_mut().project() {
StateProj::Extracting(future) => {
let (req, extracted) = ready!(future.as_mut().poll(cx));
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let future = svc.call(req);
let future = svc.call(req.into_request());
State::Call(future)
}
Err(err) => {
let res = err.into_response().map(BoxBody::new);
let res = err.into_response().map(crate::body::box_body);
return Poll::Ready(Ok(res));
}
}
@ -237,7 +239,7 @@ where
StateProj::Call(future) => {
return future
.poll(cx)
.map(|result| result.map(|response| response.map(BoxBody::new)));
.map(|result| result.map(|response| response.map(crate::body::box_body)));
}
};

View file

@ -34,7 +34,7 @@
//! You can also define your own extractors by implementing [`FromRequest`]:
//!
//! ```rust,no_run
//! use axum::{async_trait, extract::FromRequest, prelude::*};
//! use axum::{async_trait, extract::{FromRequest, RequestParts}, prelude::*};
//! use http::{StatusCode, header::{HeaderValue, USER_AGENT}};
//!
//! struct ExtractUserAgent(HeaderValue);
@ -46,8 +46,10 @@
//! {
//! type Rejection = (StatusCode, &'static str);
//!
//! async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
//! if let Some(user_agent) = req.headers().get(USER_AGENT) {
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
//! let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT));
//!
//! if let Some(user_agent) = user_agent {
//! Ok(ExtractUserAgent(user_agent.clone()))
//! } else {
//! Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
@ -175,13 +177,12 @@ use crate::{response::IntoResponse, util::ByteStr};
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use futures_util::stream::Stream;
use http::{header, HeaderMap, Method, Request, Uri, Version};
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
use rejection::*;
use serde::de::DeserializeOwned;
use std::{
collections::HashMap,
convert::Infallible,
mem,
pin::Pin,
str::FromStr,
task::{Context, Poll},
@ -212,7 +213,195 @@ pub trait FromRequest<B>: Sized {
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection>;
}
/// The type used with [`FromRequest`] to extract data from requests.
///
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<B> {
method: Option<Method>,
uri: Option<Uri>,
version: Option<Version>,
headers: Option<HeaderMap>,
extensions: Option<Extensions>,
body: Option<B>,
}
impl<B> RequestParts<B> {
pub(crate) fn new(req: Request<B>) -> Self {
let (
http::request::Parts {
method,
uri,
version,
headers,
extensions,
..
},
body,
) = req.into_parts();
RequestParts {
method: Some(method),
uri: Some(uri),
version: Some(version),
headers: Some(headers),
extensions: Some(extensions),
body: Some(body),
}
}
#[allow(clippy::wrong_self_convention)]
pub(crate) fn into_request(&mut self) -> Request<B> {
let Self {
method,
uri,
version,
headers,
extensions,
body,
} = self;
let mut req = Request::new(body.take().expect("body already extracted"));
if let Some(method) = method.take() {
*req.method_mut() = method;
}
if let Some(uri) = uri.take() {
*req.uri_mut() = uri;
}
if let Some(version) = version.take() {
*req.version_mut() = version;
}
if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
}
if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
}
req
}
/// Gets a reference to the request method.
///
/// Returns `None` if the method has been taken by another extractor.
pub fn method(&self) -> Option<&Method> {
self.method.as_ref()
}
/// Gets a mutable reference to the request method.
///
/// Returns `None` if the method has been taken by another extractor.
pub fn method_mut(&mut self) -> Option<&mut Method> {
self.method.as_mut()
}
/// Takes the method out of the request, leaving a `None` in its place.
pub fn take_method(&mut self) -> Option<Method> {
self.method.take()
}
/// Gets a reference to the request URI.
///
/// Returns `None` if the URI has been taken by another extractor.
pub fn uri(&self) -> Option<&Uri> {
self.uri.as_ref()
}
/// Gets a mutable reference to the request URI.
///
/// Returns `None` if the URI has been taken by another extractor.
pub fn uri_mut(&mut self) -> Option<&mut Uri> {
self.uri.as_mut()
}
/// Takes the URI out of the request, leaving a `None` in its place.
pub fn take_uri(&mut self) -> Option<Uri> {
self.uri.take()
}
/// Gets a reference to the request HTTP version.
///
/// Returns `None` if the HTTP version has been taken by another extractor.
pub fn version(&self) -> Option<Version> {
self.version
}
/// Gets a mutable reference to the request HTTP version.
///
/// Returns `None` if the HTTP version has been taken by another extractor.
pub fn version_mut(&mut self) -> Option<&mut Version> {
self.version.as_mut()
}
/// Takes the HTTP version out of the request, leaving a `None` in its place.
pub fn take_version(&mut self) -> Option<Version> {
self.version.take()
}
/// Gets a reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers(&self) -> Option<&HeaderMap> {
self.headers.as_ref()
}
/// Gets a mutable reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> {
self.headers.as_mut()
}
/// Takes the headers out of the request, leaving a `None` in its place.
pub fn take_headers(&mut self) -> Option<HeaderMap> {
self.headers.take()
}
/// Gets a reference to the request extensions.
///
/// Returns `None` if the extensions has been taken by another extractor.
pub fn extensions(&self) -> Option<&Extensions> {
self.extensions.as_ref()
}
/// Gets a mutable reference to the request extensions.
///
/// Returns `None` if the extensions has been taken by another extractor.
pub fn extensions_mut(&mut self) -> Option<&mut Extensions> {
self.extensions.as_mut()
}
/// Takes the extensions out of the request, leaving a `None` in its place.
pub fn take_extensions(&mut self) -> Option<Extensions> {
self.extensions.take()
}
/// Gets a reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
pub fn body(&self) -> Option<&B> {
self.body.as_ref()
}
/// Gets a mutable reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
pub fn body_mut(&mut self) -> Option<&mut B> {
self.body.as_mut()
}
/// Takes the body out of the request, leaving a `None` in its place.
pub fn take_body(&mut self) -> Option<B> {
self.body.take()
}
}
#[async_trait]
@ -223,7 +412,7 @@ where
{
type Rejection = Infallible;
async fn from_request(req: &mut Request<B>) -> Result<Option<T>, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok())
}
}
@ -236,7 +425,7 @@ where
{
type Rejection = Infallible;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await)
}
}
@ -284,8 +473,12 @@ where
{
type Rejection = QueryRejection;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().ok_or(QueryStringMissing)?;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let query = req
.uri()
.ok_or(UriAlreadyExtracted)?
.query()
.ok_or(QueryStringMissing)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Query(value))
@ -329,20 +522,24 @@ pub struct Form<T>(pub T);
impl<T, B> FromRequest<B> for Form<T>
where
T: DeserializeOwned,
B: http_body::Body + Default + Send,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<tower::BoxError>,
{
type Rejection = FormRejection;
#[allow(warnings)]
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
if !has_content_type(&req, "application/x-www-form-urlencoded") {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if !has_content_type(&req, "application/x-www-form-urlencoded")? {
Err(InvalidFormContentType)?;
}
if req.method() == Method::GET {
let query = req.uri().query().ok_or(QueryStringMissing)?;
if req.method().ok_or(MethodAlreadyExtracted)? == Method::GET {
let query = req
.uri()
.ok_or(UriAlreadyExtracted)?
.query()
.ok_or(QueryStringMissing)?;
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
@ -398,16 +595,16 @@ pub struct Json<T>(pub T);
impl<T, B> FromRequest<B> for Json<T>
where
T: DeserializeOwned,
B: http_body::Body + Default + Send,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<tower::BoxError>,
{
type Rejection = JsonRejection;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
use bytes::Buf;
if has_content_type(req, "application/json") {
if has_content_type(req, "application/json")? {
let body = take_body(req)?;
let buf = hyper::body::aggregate(body)
@ -423,20 +620,27 @@ where
}
}
fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
fn has_content_type<B>(
req: &RequestParts<B>,
expected_content_type: &str,
) -> Result<bool, HeadersAlreadyExtracted> {
let content_type = if let Some(content_type) = req
.headers()
.ok_or(HeadersAlreadyExtracted)?
.get(header::CONTENT_TYPE)
{
content_type
} else {
return false;
return Ok(false);
};
let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return false;
return Ok(false);
};
content_type.starts_with(expected_content_type)
Ok(content_type.starts_with(expected_content_type))
}
/// Extractor that gets a value from request extensions.
@ -480,11 +684,12 @@ where
T: Clone + Send + Sync + 'static,
B: Send,
{
type Rejection = MissingExtension;
type Rejection = ExtensionRejection;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let value = req
.extensions()
.ok_or(ExtensionsAlreadyExtracted)?
.get::<T>()
.ok_or(MissingExtension)
.map(|x| x.clone())?;
@ -496,13 +701,13 @@ where
#[async_trait]
impl<B> FromRequest<B> for Bytes
where
B: http_body::Body + Default + Send,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<tower::BoxError>,
{
type Rejection = BytesRejection;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let bytes = hyper::body::to_bytes(body)
@ -516,13 +721,13 @@ where
#[async_trait]
impl<B> FromRequest<B> for String
where
B: http_body::Body + Default + Send,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<tower::BoxError>,
{
type Rejection = StringRejection;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let bytes = hyper::body::to_bytes(body)
@ -572,11 +777,11 @@ where
#[async_trait]
impl<B> FromRequest<B> for BodyStream<B>
where
B: http_body::Body + Default + Unpin + Send,
B: http_body::Body + Unpin + Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let stream = BodyStream(body);
Ok(stream)
@ -586,21 +791,22 @@ where
#[async_trait]
impl<B> FromRequest<B> for Request<B>
where
B: Default + Send,
B: Send,
{
type Rejection = RequestAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
struct RequestAlreadyExtractedExt;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let all_parts = req
.method()
.zip(req.uri())
.zip(req.headers())
.zip(req.extensions())
.zip(req.body());
if req
.extensions_mut()
.insert(RequestAlreadyExtractedExt)
.is_some()
{
Err(RequestAlreadyExtracted)
if all_parts.is_some() {
Ok(req.into_request())
} else {
Ok(mem::take(req))
Err(RequestAlreadyExtracted)
}
}
}
@ -610,10 +816,10 @@ impl<B> FromRequest<B> for Method
where
B: Send,
{
type Rejection = Infallible;
type Rejection = MethodAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
Ok(req.method().clone())
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_method().ok_or(MethodAlreadyExtracted)
}
}
@ -622,10 +828,10 @@ impl<B> FromRequest<B> for Uri
where
B: Send,
{
type Rejection = Infallible;
type Rejection = UriAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone())
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_uri().ok_or(UriAlreadyExtracted)
}
}
@ -634,10 +840,10 @@ impl<B> FromRequest<B> for Version
where
B: Send,
{
type Rejection = Infallible;
type Rejection = VersionAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
Ok(req.version())
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_version().ok_or(VersionAlreadyExtracted)
}
}
@ -646,10 +852,10 @@ impl<B> FromRequest<B> for HeaderMap
where
B: Send,
{
type Rejection = Infallible;
type Rejection = HeadersAlreadyExtracted;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
Ok(mem::take(req.headers_mut()))
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_headers().ok_or(HeadersAlreadyExtracted)
}
}
@ -682,8 +888,13 @@ where
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
.ok_or(ContentLengthLimitRejection::HeadersAlreadyExtracted(
HeadersAlreadyExtracted,
))?
.get(http::header::CONTENT_LENGTH);
let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
@ -752,10 +963,10 @@ where
{
type Rejection = MissingRouteParams;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(params) = req
.extensions_mut()
.get_mut::<Option<crate::routing::UrlParams>>()
.and_then(|ext| ext.get_mut::<Option<crate::routing::UrlParams>>())
{
if let Some(params) = params {
Ok(Self(params.0.iter().cloned().collect()))
@ -810,10 +1021,12 @@ macro_rules! impl_parse_url {
type Rejection = UrlParamsRejection;
#[allow(non_snake_case)]
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let params = if let Some(params) = req
.extensions_mut()
.get_mut::<Option<crate::routing::UrlParams>>()
.and_then(|ext| {
ext.get_mut::<Option<crate::routing::UrlParams>>()
})
{
if let Some(params) = params {
params.0.clone()
@ -852,23 +1065,8 @@ macro_rules! impl_parse_url {
impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
/// Request extension used to indicate that body has been extracted and `Default` has been left in
/// its place.
struct BodyAlreadyExtractedExt;
fn take_body<B>(req: &mut Request<B>) -> Result<B, BodyAlreadyExtracted>
where
B: Default,
{
if req
.extensions_mut()
.insert(BodyAlreadyExtractedExt)
.is_some()
{
Err(BodyAlreadyExtracted)
} else {
Ok(mem::take(req.body_mut()))
}
fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted)
}
/// Extractor that extracts a typed header value from [`headers`].
@ -903,10 +1101,16 @@ where
T: headers::Header,
B: Send,
{
type Rejection = rejection::TypedHeaderRejection;
type Rejection = TypedHeaderRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let empty_headers = HeaderMap::new();
let header_values = if let Some(headers) = req.headers() {
headers.get_all(T::name())
} else {
empty_headers.get_all(T::name())
};
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
let header_values = req.headers().get_all(T::name());
T::decode(&mut header_values.iter())
.map(Self)
.map_err(|err| rejection::TypedHeaderRejection {

View file

@ -2,7 +2,7 @@
//!
//! See [`Multipart`] for more details.
use super::{rejection::*, BodyStream, FromRequest};
use super::{rejection::*, BodyStream, FromRequest, RequestParts};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::Stream;
@ -53,9 +53,10 @@ where
{
type Rejection = MultipartRejection;
async fn from_request(req: &mut http::Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let headers = req.headers().ok_or(HeadersAlreadyExtracted)?;
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}
@ -175,6 +176,7 @@ composite_rejection! {
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
HeadersAlreadyExtracted,
}
}

View file

@ -4,6 +4,41 @@ use super::IntoResponse;
use crate::body::Body;
use tower::BoxError;
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Version taken by other extractor"]
/// Rejection used if the HTTP version has been taken by another extractor.
pub struct VersionAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "URI taken by other extractor"]
/// Rejection used if the URI has been taken by another extractor.
pub struct UriAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Method taken by other extractor"]
/// Rejection used if the method has been taken by another extractor.
pub struct MethodAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
/// Rejection used if the method has been taken by another extractor.
pub struct ExtensionsAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the URI has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Query string was invalid or missing"]
@ -160,6 +195,7 @@ composite_rejection! {
/// Contains one variant for each way the [`Query`](super::Query) extractor
/// can fail.
pub enum QueryRejection {
UriAlreadyExtracted,
QueryStringMissing,
FailedToDeserializeQueryString,
}
@ -176,6 +212,9 @@ composite_rejection! {
FailedToDeserializeQueryString,
FailedToBufferBody,
BodyAlreadyExtracted,
UriAlreadyExtracted,
HeadersAlreadyExtracted,
MethodAlreadyExtracted,
}
}
@ -188,6 +227,18 @@ composite_rejection! {
InvalidJsonBody,
MissingJsonContentType,
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
}
}
composite_rejection! {
/// Rejection used for [`Extension`](super::Extension).
///
/// Contains one variant for each way the [`Extension`](super::Extension) extractor
/// can fail.
pub enum ExtensionRejection {
MissingExtension,
ExtensionsAlreadyExtracted,
}
}
@ -236,6 +287,8 @@ pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
HeadersAlreadyExtracted(HeadersAlreadyExtracted),
#[allow(missing_docs)]
Inner(T),
}
@ -247,6 +300,7 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::HeadersAlreadyExtracted(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}

View file

@ -39,7 +39,7 @@
//! the [`extract`](crate::extract) module.
use crate::{
body::BoxBody,
body::{box_body, BoxBody},
extract::FromRequest,
response::IntoResponse,
routing::{EmptyRouter, MethodFilter, RouteFuture},
@ -289,7 +289,7 @@ where
type Sealed = sealed::Hidden;
async fn call(self, _req: Request<B>) -> Response<BoxBody> {
self().await.into_response().map(BoxBody::new)
self().await.into_response().map(box_body)
}
}
@ -310,22 +310,24 @@ macro_rules! impl_handler {
{
type Sealed = sealed::Hidden;
async fn call(self, mut req: Request<B>) -> Response<BoxBody> {
async fn call(self, req: Request<B>) -> Response<BoxBody> {
let mut req = crate::extract::RequestParts::new(req);
let $head = match $head::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response().map(BoxBody::new),
Err(rejection) => return rejection.into_response().map(crate::body::box_body),
};
$(
let $tail = match $tail::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response().map(BoxBody::new),
Err(rejection) => return rejection.into_response().map(crate::body::box_body),
};
)*
let res = self($head, $($tail,)*).await;
res.into_response().map(BoxBody::new)
res.into_response().map(crate::body::box_body)
}
}
@ -380,8 +382,8 @@ where
.await
.map_err(IntoResponse::into_response)
{
Ok(res) => res.map(BoxBody::new),
Err(res) => res.map(BoxBody::new),
Ok(res) => res.map(box_body),
Err(res) => res.map(box_body),
}
}
}

View file

@ -1,6 +1,11 @@
//! Routing between [`Service`]s.
use crate::{body::BoxBody, buffer::MpscBuffer, response::IntoResponse, util::ByteStr};
use crate::{
body::{box_body, BoxBody},
buffer::MpscBuffer,
response::IntoResponse,
util::ByteStr,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::future;
@ -165,7 +170,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized {
.layer_fn(BoxRoute)
.layer_fn(MpscBuffer::new)
.layer(BoxService::layer())
.layer(MapResponseBodyLayer::new(BoxBody::new))
.layer(MapResponseBodyLayer::new(box_body))
.service(self)
}
@ -399,7 +404,7 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> {
}
fn call(&mut self, _req: Request<B>) -> Self::Future {
let mut res = Response::new(BoxBody::empty());
let mut res = Response::new(crate::body::empty());
*res.status_mut() = StatusCode::NOT_FOUND;
EmptyRouterFuture(future::ok(res))
}

View file

@ -1,6 +1,9 @@
//! [`Service`](tower::Service) future types.
use crate::{body::BoxBody, response::IntoResponse};
use crate::{
body::{box_body, BoxBody},
response::IntoResponse,
};
use bytes::Bytes;
use futures_util::ready;
use http::Response;
@ -36,11 +39,11 @@ where
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Ok(res) => Ok(res.map(box_body)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
Ok(res.map(box_body)).into()
}
}
}

View file

@ -87,7 +87,7 @@
//! [load shed]: tower::load_shed
use crate::{
body::BoxBody,
body::{box_body, BoxBody},
response::IntoResponse,
routing::{EmptyRouter, MethodFilter, RouteFuture},
};
@ -656,7 +656,7 @@ where
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = ready!(self.project().0.poll(cx))?;
let res = res.map(BoxBody::new);
let res = res.map(box_body);
Poll::Ready(Ok(res))
}
}

View file

@ -1,4 +1,7 @@
use crate::{handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, service};
use crate::{
extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::MethodFilter,
service,
};
use bytes::Bytes;
use http::{header::AUTHORIZATION, Request, Response, StatusCode};
use hyper::{Body, Server};
@ -105,7 +108,7 @@ async fn consume_body_to_json_requires_json_content_type() {
let app = route(
"/",
post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo }),
post(|input: extract::Json<Input>| async { input.0.foo }),
);
let addr = run_in_background(app).await;
@ -675,9 +678,10 @@ async fn test_extractor_middleware() {
{
type Rejection = StatusCode;
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.expect("headers already extracted")
.get("authorization")
.and_then(|v| v.to_str().ok())
{