Only allow last extractor to mutate the request (#1272)

* Only allow last extractor to mutate the request

* Change `FromRequest` and add `FromRequestParts` trait (#1275)

* Add `Once`/`Mut` type parameter for `FromRequest` and `RequestParts`

* 🪄

* split traits

* `FromRequest` for tuples

* Remove `BodyAlreadyExtracted`

* don't need fully qualified path

* don't export `Once` and `Mut`

* remove temp tests

* depend on axum again

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* Port `Handler` and most extractors (#1277)

* Port `Handler` and most extractors

* Put `M` inside `Handler` impls, not trait itself

* comment out tuples for now

* fix lints

* Reorder arguments to `Handler` (#1281)

I think `Request<B>, Arc<S>` is better since its consistent with
`FromRequest` and `FromRequestParts`.

* Port most things in axum-extra (#1282)

* Port `#[derive(TypedPath)]` and `#[debug_handler]` (#1283)

* port #[derive(TypedPath)]

* wip: #[debug_handler]

* fix #[debug_handler]

* don't need itertools

* also require `Send`

* update expected error

* support fully qualified `self`

* Implement FromRequest[Parts] for tuples (#1286)

* Port docs for axum and axum-core (#1285)

* Port axum-extra (#1287)

* Port axum-extra

* Update axum-core/Cargo.toml

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* remove `impl FromRequest for Either*`

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* New FromRequest[Parts] trait cleanup (#1288)

* Make private module truly private again

* Simplify tuple FromRequest implementation

* Port `#[derive(FromRequest)]` (#1289)

* fix tests

* fix docs

* revert examples

* fix docs link

* fix intra docs links

* Port examples (#1291)

* Document wrapping other extractors (#1292)

* axum-extra doesn't need to depend on axum-core (#1294)

Missed this in https://github.com/tokio-rs/axum/pull/1287

* Add `FromRequest` changes to changelogs (#1293)

* Update changelog

* Remove default type for `S` in `Handler`

* Clarify which types have default types for `S`

* Apply suggestions from code review

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* remove unused import

* Rename `Mut` and `Once` (#1296)

* fix trybuild expected output

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-08-22 12:23:20 +02:00 committed by GitHub
parent f1769e5134
commit be624306f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
104 changed files with 1513 additions and 1936 deletions

View file

@ -7,10 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which - **breaking:** `FromRequest` has been reworked and `RequestParts` has been
represents the state ([#1155]) removed. See axum's changelog for more details ([#1272])
- **added:** Added new `FromRequestParts` trait. See axum's changelog for more
details ([#1272])
- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272])
[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1272]: https://github.com/tokio-rs/axum/pull/1272
# 0.2.7 (10. July, 2022) # 0.2.7 (10. July, 2022)

View file

@ -4,11 +4,10 @@
//! //!
//! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html //! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
use self::rejection::*;
use crate::response::IntoResponse; use crate::response::IntoResponse;
use async_trait::async_trait; use async_trait::async_trait;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use http::{request::Parts, Request};
use std::{convert::Infallible, sync::Arc}; use std::convert::Infallible;
pub mod rejection; pub mod rejection;
@ -18,9 +17,44 @@ mod tuple;
pub use self::from_ref::FromRef; pub use self::from_ref::FromRef;
mod private {
#[derive(Debug, Clone, Copy)]
pub enum ViaParts {}
#[derive(Debug, Clone, Copy)]
pub enum ViaRequest {}
}
/// Types that can be created from request parts.
///
/// Extractors that implement `FromRequestParts` cannot consume the request body and can thus be
/// run in any order for handlers.
///
/// If your extractor needs to consume the request body then you should implement [`FromRequest`]
/// and not [`FromRequestParts`].
///
/// See [`axum::extract`] for more general docs about extraxtors.
///
/// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html
#[async_trait]
pub trait FromRequestParts<S>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection>;
}
/// Types that can be created from requests. /// Types that can be created from requests.
/// ///
/// See [`axum::extract`] for more details. /// Extractors that implement `FromRequest` can consume the request body and can thus only be run
/// once for handlers.
///
/// If your extractor doesn't need to consume the request body then you should implement
/// [`FromRequestParts`] and not [`FromRequest`].
///
/// See [`axum::extract`] for more general docs about extraxtors.
/// ///
/// # What is the `B` type parameter? /// # What is the `B` type parameter?
/// ///
@ -39,7 +73,8 @@ pub use self::from_ref::FromRef;
/// ```rust /// ```rust
/// use axum::{ /// use axum::{
/// async_trait, /// async_trait,
/// extract::{FromRequest, RequestParts}, /// extract::FromRequest,
/// http::Request,
/// }; /// };
/// ///
/// struct MyExtractor; /// struct MyExtractor;
@ -48,12 +83,12 @@ pub use self::from_ref::FromRef;
/// impl<S, B> FromRequest<S, B> for MyExtractor /// impl<S, B> FromRequest<S, B> for MyExtractor
/// where /// where
/// // these bounds are required by `async_trait` /// // these bounds are required by `async_trait`
/// B: Send, /// B: Send + 'static,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// type Rejection = http::StatusCode; /// type Rejection = http::StatusCode;
/// ///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
/// // ... /// // ...
/// # unimplemented!() /// # unimplemented!()
/// } /// }
@ -63,231 +98,45 @@ pub use self::from_ref::FromRef;
/// This ensures your extractor is as flexible as possible. /// This ensures your extractor is as flexible as possible.
/// ///
/// [`http::Request<B>`]: http::Request /// [`http::Request<B>`]: http::Request
/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html /// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html
#[async_trait] #[async_trait]
pub trait FromRequest<S, B>: Sized { pub trait FromRequest<S, B, M = private::ViaRequest>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is /// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response. /// a kind of error that can be converted into a response.
type Rejection: IntoResponse; type Rejection: IntoResponse;
/// Perform the extraction. /// Perform the extraction.
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection>; async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection>;
} }
/// The type used with [`FromRequest`] to extract data from requests. #[async_trait]
/// impl<S, B, T> FromRequest<S, B, private::ViaParts> for T
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<S, B> {
pub(crate) state: Arc<S>,
method: Method,
uri: Uri,
version: Version,
headers: HeaderMap,
extensions: Extensions,
body: Option<B>,
}
impl<B> RequestParts<(), B> {
/// Create a new `RequestParts` without any state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn new(req: Request<B>) -> Self {
Self::with_state((), req)
}
}
impl<S, B> RequestParts<S, B> {
/// Create a new `RequestParts` with the given state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
Self::with_state_arc(Arc::new(state), req)
}
/// Create a new `RequestParts` with the given [`Arc`]'ed state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state_arc(state: Arc<S>, req: Request<B>) -> Self {
let (
http::request::Parts {
method,
uri,
version,
headers,
extensions,
..
},
body,
) = req.into_parts();
RequestParts {
state,
method,
uri,
version,
headers,
extensions,
body: Some(body),
}
}
/// Apply an extractor to this `RequestParts`.
///
/// `req.extract::<Extractor>()` is equivalent to `Extractor::from_request(req)`.
/// This function simply exists as a convenience.
///
/// # Example
///
/// ```
/// # struct MyExtractor {}
///
/// use std::convert::Infallible;
///
/// use async_trait::async_trait;
/// use axum::extract::{FromRequest, RequestParts};
/// use http::{Method, Uri};
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Infallible> {
/// let method = req.extract::<Method>().await?;
/// let path = req.extract::<Uri>().await?.path().to_owned();
///
/// todo!()
/// }
/// }
/// ```
pub async fn extract<E>(&mut self) -> Result<E, E::Rejection>
where where
E: FromRequest<S, B>, B: Send + 'static,
S: Send + Sync,
T: FromRequestParts<S>,
{ {
E::from_request(self).await type Rejection = <Self as FromRequestParts<S>>::Rejection;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, _) = req.into_parts();
Self::from_request_parts(&mut parts, state).await
}
} }
/// Convert this `RequestParts` back into a [`Request`]. #[async_trait]
/// impl<S, T> FromRequestParts<S> for Option<T>
/// Fails if The request body has been extracted, that is [`take_body`] has where
/// been called. T: FromRequestParts<S>,
/// S: Send + Sync,
/// [`take_body`]: RequestParts::take_body {
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> { type Rejection = Infallible;
let Self {
state: _,
method,
uri,
version,
headers,
extensions,
mut body,
} = self;
let mut req = if let Some(body) = body.take() { async fn from_request_parts(
Request::new(body) parts: &mut Parts,
} else { state: &S,
return Err(BodyAlreadyExtracted); ) -> Result<Option<T>, Self::Rejection> {
}; Ok(T::from_request_parts(parts, state).await.ok())
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;
*req.headers_mut() = headers;
*req.extensions_mut() = extensions;
Ok(req)
}
/// Gets a reference to the request method.
pub fn method(&self) -> &Method {
&self.method
}
/// Gets a mutable reference to the request method.
pub fn method_mut(&mut self) -> &mut Method {
&mut self.method
}
/// Gets a reference to the request URI.
pub fn uri(&self) -> &Uri {
&self.uri
}
/// Gets a mutable reference to the request URI.
pub fn uri_mut(&mut self) -> &mut Uri {
&mut self.uri
}
/// Get the request HTTP version.
pub fn version(&self) -> Version {
self.version
}
/// Gets a mutable reference to the request HTTP version.
pub fn version_mut(&mut self) -> &mut Version {
&mut self.version
}
/// Gets a reference to the request headers.
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
/// Gets a mutable reference to the request headers.
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
/// Gets a reference to the request extensions.
pub fn extensions(&self) -> &Extensions {
&self.extensions
}
/// Gets a mutable reference to the request extensions.
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
/// Gets a reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
pub fn body(&self) -> Option<&B> {
self.body.as_ref()
}
/// Gets a mutable reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
// this returns `&mut Option<B>` rather than `Option<&mut B>` such that users can use it to set the body.
pub fn body_mut(&mut self) -> &mut Option<B> {
&mut self.body
}
/// Takes the body out of the request, leaving a `None` in its place.
pub fn take_body(&mut self) -> Option<B> {
self.body.take()
}
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
} }
} }
@ -295,13 +144,26 @@ impl<S, B> RequestParts<S, B> {
impl<S, T, B> FromRequest<S, B> for Option<T> impl<S, T, B> FromRequest<S, B> for Option<T>
where where
T: FromRequest<S, B>, T: FromRequest<S, B>,
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Option<T>, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok()) Ok(T::from_request(req, state).await.ok())
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await)
} }
} }
@ -309,12 +171,12 @@ where
impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection> impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where where
T: FromRequest<S, B>, T: FromRequest<S, B>,
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await) Ok(T::from_request(req, state).await)
} }
} }

View file

@ -1,35 +1,6 @@
//! Rejection response types. //! Rejection response types.
use crate::{ use crate::BoxError;
response::{IntoResponse, Response},
BoxError,
};
use http::StatusCode;
use std::fmt;
/// Rejection type used if you try and extract the request body more than
/// once.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct BodyAlreadyExtracted;
impl BodyAlreadyExtracted {
const BODY: &'static str = "Cannot have two request body extractors for a single handler";
}
impl IntoResponse for BodyAlreadyExtracted {
fn into_response(self) -> Response {
(StatusCode::INTERNAL_SERVER_ERROR, Self::BODY).into_response()
}
}
impl fmt::Display for BodyAlreadyExtracted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", Self::BODY)
}
}
impl std::error::Error for BodyAlreadyExtracted {}
composite_rejection! { composite_rejection! {
/// Rejection type for extractors that buffer the request body. Used if the /// Rejection type for extractors that buffer the request body. Used if the
@ -85,7 +56,6 @@ composite_rejection! {
/// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor /// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor
/// can fail. /// can fail.
pub enum BytesRejection { pub enum BytesRejection {
BodyAlreadyExtracted,
FailedToBufferBody, FailedToBufferBody,
} }
} }
@ -95,7 +65,6 @@ composite_rejection! {
/// ///
/// Contains one variant for each way the [`String`] extractor can fail. /// Contains one variant for each way the [`String`] extractor can fail.
pub enum StringRejection { pub enum StringRejection {
BodyAlreadyExtracted,
FailedToBufferBody, FailedToBufferBody,
InvalidUtf8, InvalidUtf8,
} }

View file

@ -1,9 +1,9 @@
use super::{rejection::*, FromRequest, RequestParts}; use super::{rejection::*, FromRequest, FromRequestParts};
use crate::BoxError; use crate::BoxError;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use http::{request::Parts, HeaderMap, Method, Request, Uri, Version};
use std::{convert::Infallible, sync::Arc}; use std::convert::Infallible;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Request<B> impl<S, B> FromRequest<S, B> for Request<B>
@ -11,62 +11,46 @@ where
B: Send, B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let req = std::mem::replace( Ok(req)
req,
RequestParts {
state: Arc::clone(&req.state),
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: HeaderMap::new(),
extensions: Extensions::default(),
body: None,
},
);
req.try_into_request()
} }
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Method impl<S> FromRequestParts<S> for Method
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(req.method().clone()) Ok(parts.method.clone())
} }
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Uri impl<S> FromRequestParts<S> for Uri
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone()) Ok(parts.uri.clone())
} }
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Version impl<S> FromRequestParts<S> for Version
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(req.version()) Ok(parts.version)
} }
} }
@ -76,30 +60,29 @@ where
/// ///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for HeaderMap impl<S> FromRequestParts<S> for HeaderMap
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(req.headers().clone()) Ok(parts.headers.clone())
} }
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Bytes impl<S, B> FromRequest<S, B> for Bytes
where where
B: http_body::Body + Send, B: http_body::Body + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = BytesRejection; type Rejection = BytesRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; let body = req.into_body();
let bytes = crate::body::to_bytes(body) let bytes = crate::body::to_bytes(body)
.await .await
@ -112,15 +95,15 @@ where
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for String impl<S, B> FromRequest<S, B> for String
where where
B: http_body::Body + Send, B: http_body::Body + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = StringRejection; type Rejection = StringRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; let body = req.into_body();
let bytes = crate::body::to_bytes(body) let bytes = crate::body::to_bytes(body)
.await .await
@ -134,40 +117,14 @@ where
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for http::request::Parts impl<S, B> FromRequest<S, B> for Parts
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await); Ok(req.into_parts().0)
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = unwrap_infallible(HeaderMap::from_request(req).await);
let extensions = std::mem::take(req.extensions_mut());
let mut temp_request = Request::new(());
*temp_request.method_mut() = method;
*temp_request.uri_mut() = uri;
*temp_request.version_mut() = version;
*temp_request.headers_mut() = headers;
*temp_request.extensions_mut() = extensions;
let (parts, _) = temp_request.into_parts();
Ok(parts)
} }
} }
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted)
}

View file

@ -1,41 +1,143 @@
use super::{FromRequest, RequestParts}; use super::{FromRequest, FromRequestParts};
use crate::response::{IntoResponse, Response}; use crate::response::{IntoResponse, Response};
use async_trait::async_trait; use async_trait::async_trait;
use http::request::{Parts, Request};
use std::convert::Infallible; use std::convert::Infallible;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for () impl<S> FromRequestParts<S> for ()
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<S, B>) -> Result<(), Self::Rejection> { async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> {
Ok(()) Ok(())
} }
} }
macro_rules! impl_from_request { macro_rules! impl_from_request {
() => {}; (
[$($ty:ident),*], $last:ident
( $($ty:ident),* $(,)? ) => { ) => {
#[async_trait] #[async_trait]
#[allow(non_snake_case)] #[allow(non_snake_case, unused_mut, unused_variables)]
impl<S, B, $($ty,)*> FromRequest<S, B> for ($($ty,)*) impl<S, $($ty,)* $last> FromRequestParts<S> for ($($ty,)* $last,)
where where
$( $ty: FromRequest<S, B> + Send, )* $( $ty: FromRequestParts<S> + Send, )*
B: Send, $last: FromRequestParts<S> + Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
$( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* $(
Ok(($($ty,)*)) let $ty = $ty::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
)*
let $last = $last::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
Ok(($($ty,)* $last,))
}
}
// This impl must not be generic over M, otherwise it would conflict with the blanket
// implementation of `FromRequest<S, B, Mut>` for `T: FromRequestParts<S>`.
#[async_trait]
#[allow(non_snake_case, unused_mut, unused_variables)]
impl<S, B, $($ty,)* $last> FromRequest<S, B> for ($($ty,)* $last,)
where
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B> + Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, body) = req.into_parts();
$(
let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?;
)*
let req = Request::from_parts(parts, body);
let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?;
Ok(($($ty,)* $last,))
} }
} }
}; };
} }
all_the_tuples!(impl_from_request); impl_from_request!([], T1);
impl_from_request!([T1], T2);
impl_from_request!([T1, T2], T3);
impl_from_request!([T1, T2, T3], T4);
impl_from_request!([T1, T2, T3, T4], T5);
impl_from_request!([T1, T2, T3, T4, T5], T6);
impl_from_request!([T1, T2, T3, T4, T5, T6], T7);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http::Method;
use crate::extract::{FromRequest, FromRequestParts};
fn assert_from_request<M, T>()
where
T: FromRequest<(), http_body::Full<Bytes>, M>,
{
}
fn assert_from_request_parts<T: FromRequestParts<()>>() {}
#[test]
fn unit() {
assert_from_request_parts::<()>();
assert_from_request::<_, ()>();
}
#[test]
fn tuple_of_one() {
assert_from_request_parts::<(Method,)>();
assert_from_request::<_, (Method,)>();
assert_from_request::<_, (Bytes,)>();
}
#[test]
fn tuple_of_two() {
assert_from_request_parts::<((), ())>();
assert_from_request::<_, ((), ())>();
assert_from_request::<_, (Method, Bytes)>();
}
#[test]
fn nested_tuple() {
assert_from_request_parts::<(((Method,),),)>();
assert_from_request::<_, ((((Bytes,),),),)>();
}
}

View file

@ -4,15 +4,50 @@
//! //!
//! ``` //! ```
//! use axum_extra::either::Either3; //! use axum_extra::either::Either3;
//! use axum::{body::Bytes, Json}; //! use axum::{
//! body::Bytes,
//! Router,
//! async_trait,
//! routing::get,
//! extract::FromRequestParts,
//! };
//!
//! // extractors for checking permissions
//! struct AdminPermissions {}
//!
//! #[async_trait]
//! impl<S> FromRequestParts<S> for AdminPermissions
//! where
//! S: Send + Sync,
//! {
//! // check for admin permissions...
//! # type Rejection = ();
//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
//! # todo!()
//! # }
//! }
//!
//! struct User {}
//!
//! #[async_trait]
//! impl<S> FromRequestParts<S> for User
//! where
//! S: Send + Sync,
//! {
//! // check for a logged in user...
//! # type Rejection = ();
//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
//! # todo!()
//! # }
//! }
//! //!
//! async fn handler( //! async fn handler(
//! body: Either3<Json<serde_json::Value>, String, Bytes>, //! body: Either3<AdminPermissions, User, ()>,
//! ) { //! ) {
//! match body { //! match body {
//! Either3::E1(json) => { /* ... */ } //! Either3::E1(admin) => { /* ... */ }
//! Either3::E2(string) => { /* ... */ } //! Either3::E2(user) => { /* ... */ }
//! Either3::E3(bytes) => { /* ... */ } //! Either3::E3(guest) => { /* ... */ }
//! } //! }
//! } //! }
//! # //! #
@ -60,9 +95,10 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use http::request::Parts;
/// Combines two extractors or responses into a single type. /// Combines two extractors or responses into a single type.
/// ///
@ -190,23 +226,22 @@ macro_rules! impl_traits_for_either {
$last:ident $(,)? $last:ident $(,)?
) => { ) => {
#[async_trait] #[async_trait]
impl<S, B, $($ident),*, $last> FromRequest<S, B> for $either<$($ident),*, $last> impl<S, $($ident),*, $last> FromRequestParts<S> for $either<$($ident),*, $last>
where where
$($ident: FromRequest<S, B>),*, $($ident: FromRequestParts<S>),*,
$last: FromRequest<S, B>, $last: FromRequestParts<S>,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = $last::Rejection; type Rejection = $last::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
$( $(
if let Ok(value) = req.extract().await { if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await {
return Ok(Self::$ident(value)); return Ok(Self::$ident(value));
} }
)* )*
req.extract().await.map(Self::$last) FromRequestParts::from_request_parts(parts, state).await.map(Self::$last)
} }
} }

View file

@ -1,7 +1,8 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{Extension, FromRequest, RequestParts}, extract::{Extension, FromRequest, FromRequestParts},
}; };
use http::{request::Parts, Request};
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
/// Cache results of other extractors. /// Cache results of other extractors.
@ -20,24 +21,23 @@ use std::ops::{Deref, DerefMut};
/// use axum_extra::extract::Cached; /// use axum_extra::extract::Cached;
/// use axum::{ /// use axum::{
/// async_trait, /// async_trait,
/// extract::{FromRequest, RequestParts}, /// extract::FromRequestParts,
/// body::BoxBody, /// body::BoxBody,
/// response::{IntoResponse, Response}, /// response::{IntoResponse, Response},
/// http::StatusCode, /// http::{StatusCode, request::Parts},
/// }; /// };
/// ///
/// #[derive(Clone)] /// #[derive(Clone)]
/// struct Session { /* ... */ } /// struct Session { /* ... */ }
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for Session /// impl<S> FromRequestParts<S> for Session
/// where /// where
/// B: Send,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// type Rejection = (StatusCode, String); /// type Rejection = (StatusCode, String);
/// ///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // load session... /// // load session...
/// # unimplemented!() /// # unimplemented!()
/// } /// }
@ -46,19 +46,18 @@ use std::ops::{Deref, DerefMut};
/// struct CurrentUser { /* ... */ } /// struct CurrentUser { /* ... */ }
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for CurrentUser /// impl<S> FromRequestParts<S> for CurrentUser
/// where /// where
/// B: Send,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// type Rejection = Response; /// type Rejection = Response;
/// ///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // loading a `CurrentUser` requires first loading the `Session` /// // loading a `CurrentUser` requires first loading the `Session`
/// // /// //
/// // by using `Cached<Session>` we avoid extracting the session more than /// // by using `Cached<Session>` we avoid extracting the session more than
/// // once, in case other extractors for the same request also loads the session /// // once, in case other extractors for the same request also loads the session
/// let session: Session = Cached::<Session>::from_request(req) /// let session: Session = Cached::<Session>::from_request_parts(parts, state)
/// .await /// .await
/// .map_err(|err| err.into_response())? /// .map_err(|err| err.into_response())?
/// .0; /// .0;
@ -92,18 +91,40 @@ struct CachedEntry<T>(T);
#[async_trait] #[async_trait]
impl<S, B, T> FromRequest<S, B> for Cached<T> impl<S, B, T> FromRequest<S, B> for Cached<T>
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
T: FromRequest<S, B> + Clone + Send + Sync + 'static, T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{ {
type Rejection = T::Rejection; type Rejection = T::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await { let (mut parts, _) = req.into_parts();
match Extension::<CachedEntry<T>>::from_request_parts(&mut parts, state).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => { Err(_) => {
let value = T::from_request(req).await?; let value = T::from_request_parts(&mut parts, state).await?;
req.extensions_mut().insert(CachedEntry(value.clone())); parts.extensions.insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Cached<T>
where
S: Send + Sync,
T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => {
let value = T::from_request_parts(parts, state).await?;
parts.extensions.insert(CachedEntry(value.clone()));
Ok(Self(value)) Ok(Self(value))
} }
} }
@ -127,7 +148,8 @@ impl<T> DerefMut for Cached<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::http::Request; use axum::{extract::FromRequestParts, http::Request};
use http::request::Parts;
use std::{ use std::{
convert::Infallible, convert::Infallible,
sync::atomic::{AtomicU32, Ordering}, sync::atomic::{AtomicU32, Ordering},
@ -142,25 +164,33 @@ mod tests {
struct Extractor(Instant); struct Extractor(Instant);
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Extractor impl<S> FromRequestParts<S> for Extractor
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst); COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now())) Ok(Self(Instant::now()))
} }
} }
let mut req = RequestParts::new(Request::new(())); let (mut parts, _) = Request::new(()).into_parts();
let first = Cached::<Extractor>::from_request(&mut req).await.unwrap().0; let first = Cached::<Extractor>::from_request_parts(&mut parts, &())
.await
.unwrap()
.0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1); assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
let second = Cached::<Extractor>::from_request(&mut req).await.unwrap().0; let second = Cached::<Extractor>::from_request_parts(&mut parts, &())
.await
.unwrap()
.0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1); assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
assert_eq!(first, second); assert_eq!(first, second);

View file

@ -4,11 +4,12 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
}; };
use http::{ use http::{
header::{COOKIE, SET_COOKIE}, header::{COOKIE, SET_COOKIE},
request::Parts,
HeaderMap, HeaderMap,
}; };
use std::convert::Infallible; use std::convert::Infallible;
@ -88,15 +89,14 @@ pub struct CookieJar {
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for CookieJar impl<S> FromRequestParts<S> for CookieJar
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(Self::from_headers(req.headers())) Ok(Self::from_headers(&parts.headers))
} }
} }
@ -115,7 +115,9 @@ impl CookieJar {
/// The cookies in `headers` will be added to the jar. /// The cookies in `headers` will be added to the jar.
/// ///
/// This is inteded to be used in middleware and other places where it might be difficult to /// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `CookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap) -> Self { pub fn from_headers(headers: &HeaderMap) -> Self {
let mut jar = cookie::CookieJar::new(); let mut jar = cookie::CookieJar::new();
for cookie in cookies_from_request(headers) { for cookie in cookies_from_request(headers) {
@ -127,10 +129,12 @@ impl CookieJar {
/// Create a new empty `CookieJar`. /// Create a new empty `CookieJar`.
/// ///
/// This is inteded to be used in middleware and other places where it might be difficult to /// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `CookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
/// ///
/// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for /// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for
/// CookieJar`. /// CookieJar`.
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }

View file

@ -1,11 +1,11 @@
use super::{cookies_from_request, set_cookies, Cookie, Key}; use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRef, FromRequest, RequestParts}, extract::{FromRef, FromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
}; };
use cookie::PrivateJar; use cookie::PrivateJar;
use http::HeaderMap; use http::{request::Parts, HeaderMap};
use std::{convert::Infallible, fmt, marker::PhantomData}; use std::{convert::Infallible, fmt, marker::PhantomData};
/// Extractor that grabs private cookies from the request and manages the jar. /// Extractor that grabs private cookies from the request and manages the jar.
@ -87,22 +87,21 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
} }
#[async_trait] #[async_trait]
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K> impl<S, K> FromRequestParts<S> for PrivateCookieJar<K>
where where
B: Send,
S: Send + Sync, S: Send + Sync,
K: FromRef<S> + Into<Key>, K: FromRef<S> + Into<Key>,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state()); let k = K::from_ref(state);
let key = k.into(); let key = k.into();
let PrivateCookieJar { let PrivateCookieJar {
jar, jar,
key, key,
_marker: _, _marker: _,
} = PrivateCookieJar::from_headers(req.headers(), key); } = PrivateCookieJar::from_headers(&parts.headers, key);
Ok(PrivateCookieJar { Ok(PrivateCookieJar {
jar, jar,
key, key,
@ -117,7 +116,9 @@ impl PrivateCookieJar {
/// The valid cookies in `headers` will be added to the jar. /// The valid cookies in `headers` will be added to the jar.
/// ///
/// This is inteded to be used in middleware and other where places it might be difficult to /// This is inteded to be used in middleware and other where places it might be difficult to
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { pub fn from_headers(headers: &HeaderMap, key: Key) -> Self {
let mut jar = cookie::CookieJar::new(); let mut jar = cookie::CookieJar::new();
let mut private_jar = jar.private_mut(&key); let mut private_jar = jar.private_mut(&key);
@ -137,7 +138,9 @@ impl PrivateCookieJar {
/// Create a new empty `PrivateCookieJarIter`. /// Create a new empty `PrivateCookieJarIter`.
/// ///
/// This is inteded to be used in middleware and other places where it might be difficult to /// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new(key: Key) -> Self { pub fn new(key: Key) -> Self {
Self { Self {
jar: Default::default(), jar: Default::default(),

View file

@ -1,12 +1,12 @@
use super::{cookies_from_request, set_cookies}; use super::{cookies_from_request, set_cookies};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRef, FromRequest, RequestParts}, extract::{FromRef, FromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
}; };
use cookie::SignedJar; use cookie::SignedJar;
use cookie::{Cookie, Key}; use cookie::{Cookie, Key};
use http::HeaderMap; use http::{request::Parts, HeaderMap};
use std::{convert::Infallible, fmt, marker::PhantomData}; use std::{convert::Infallible, fmt, marker::PhantomData};
/// Extractor that grabs signed cookies from the request and manages the jar. /// Extractor that grabs signed cookies from the request and manages the jar.
@ -105,22 +105,21 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
} }
#[async_trait] #[async_trait]
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K> impl<S, K> FromRequestParts<S> for SignedCookieJar<K>
where where
B: Send,
S: Send + Sync, S: Send + Sync,
K: FromRef<S> + Into<Key>, K: FromRef<S> + Into<Key>,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state()); let k = K::from_ref(state);
let key = k.into(); let key = k.into();
let SignedCookieJar { let SignedCookieJar {
jar, jar,
key, key,
_marker: _, _marker: _,
} = SignedCookieJar::from_headers(req.headers(), key); } = SignedCookieJar::from_headers(&parts.headers, key);
Ok(SignedCookieJar { Ok(SignedCookieJar {
jar, jar,
key, key,
@ -135,7 +134,9 @@ impl SignedCookieJar {
/// The valid cookies in `headers` will be added to the jar. /// The valid cookies in `headers` will be added to the jar.
/// ///
/// This is inteded to be used in middleware and other places where it might be difficult to /// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap, key: Key) -> Self { pub fn from_headers(headers: &HeaderMap, key: Key) -> Self {
let mut jar = cookie::CookieJar::new(); let mut jar = cookie::CookieJar::new();
let mut signed_jar = jar.signed_mut(&key); let mut signed_jar = jar.signed_mut(&key);
@ -155,7 +156,9 @@ impl SignedCookieJar {
/// Create a new empty `SignedCookieJar`. /// Create a new empty `SignedCookieJar`.
/// ///
/// This is inteded to be used in middleware and other places where it might be difficult to /// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`]. /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new(key: Key) -> Self { pub fn new(key: Key) -> Self {
Self { Self {
jar: Default::default(), jar: Default::default(),

View file

@ -3,12 +3,12 @@ use axum::{
body::HttpBody, body::HttpBody,
extract::{ extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType}, rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType},
FromRequest, RequestParts, FromRequest,
}, },
BoxError, BoxError,
}; };
use bytes::Bytes; use bytes::Bytes;
use http::{header, Method}; use http::{header, HeaderMap, Method, Request};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::ops::Deref; use std::ops::Deref;
@ -58,25 +58,25 @@ impl<T> Deref for Form<T> {
impl<T, S, B> FromRequest<S, B> for Form<T> impl<T, S, B> FromRequest<S, B> for Form<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = FormRejection; type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET { if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query) let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Form(value)) Ok(Form(value))
} else { } else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) { if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType::default().into()); return Err(InvalidFormContentType::default().into());
} }
let bytes = Bytes::from_request(req).await?; let bytes = Bytes::from_request(req, state).await?;
let value = serde_html_form::from_bytes(&bytes) let value = serde_html_form::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
@ -86,8 +86,8 @@ where
} }
// this is duplicated in `axum/src/extract/mod.rs` // this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type<S, B>(req: &RequestParts<S, B>, expected_content_type: &mime::Mime) -> bool { fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type content_type
} else { } else {
return false; return false;

View file

@ -2,9 +2,10 @@ use axum::{
async_trait, async_trait,
extract::{ extract::{
rejection::{FailedToDeserializeQueryString, QueryRejection}, rejection::{FailedToDeserializeQueryString, QueryRejection},
FromRequest, RequestParts, FromRequestParts,
}, },
}; };
use http::request::Parts;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::ops::Deref; use std::ops::Deref;
@ -58,16 +59,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T); pub struct Query<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Query<T> impl<T, S> FromRequestParts<S> for Query<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = QueryRejection; type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default(); let query = parts.uri.query().unwrap_or_default();
let value = serde_html_form::from_str(query) let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Query(value)) Ok(Query(value))

View file

@ -1,6 +1,8 @@
use axum::async_trait; use axum::async_trait;
use axum::extract::{FromRequest, RequestParts}; use axum::extract::{FromRequest, FromRequestParts};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use http::request::Parts;
use http::Request;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -109,23 +111,40 @@ impl<E, R> DerefMut for WithRejection<E, R> {
#[async_trait] #[async_trait]
impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R> impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R>
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
E: FromRequest<S, B>, E: FromRequest<S, B>,
R: From<E::Rejection> + IntoResponse, R: From<E::Rejection> + IntoResponse,
{ {
type Rejection = R; type Rejection = R;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let extractor = req.extract::<E>().await?; let extractor = E::from_request(req, state).await?;
Ok(WithRejection(extractor, PhantomData))
}
}
#[async_trait]
impl<E, R, S> FromRequestParts<S> for WithRejection<E, R>
where
S: Send + Sync,
E: FromRequestParts<S>,
R: From<E::Rejection> + IntoResponse,
{
type Rejection = R;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let extractor = E::from_request_parts(parts, state).await?;
Ok(WithRejection(extractor, PhantomData)) Ok(WithRejection(extractor, PhantomData))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use axum::extract::FromRequestParts;
use axum::http::Request; use axum::http::Request;
use axum::response::Response; use axum::response::Response;
use http::request::Parts;
use super::*; use super::*;
@ -135,14 +154,16 @@ mod tests {
struct TestRejection; struct TestRejection;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for TestExtractor impl<S> FromRequestParts<S> for TestExtractor
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
Err(()) Err(())
} }
} }
@ -159,12 +180,14 @@ mod tests {
} }
} }
let mut req = RequestParts::new(Request::new(())); let req = Request::new(());
let result = WithRejection::<TestExtractor, TestRejection>::from_request(req, &()).await;
assert!(matches!(result, Err(TestRejection)));
let result = req let (mut parts, _) = Request::new(()).into_parts();
.extract::<WithRejection<TestExtractor, TestRejection>>() let result =
WithRejection::<TestExtractor, TestRejection>::from_request_parts(&mut parts, &())
.await; .await;
assert!(matches!(result, Err(TestRejection)));
assert!(matches!(result, Err(TestRejection)))
} }
} }

View file

@ -1,7 +1,7 @@
//! Additional handler utilities. //! Additional handler utilities.
use axum::{ use axum::{
extract::{FromRequest, RequestParts}, extract::FromRequest,
handler::Handler, handler::Handler,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
@ -26,8 +26,8 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// Call the handler with the extracted inputs. /// Call the handler with the extracted inputs.
fn call( fn call(
self, self,
state: Arc<S>,
extractors: T, extractors: T,
state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future; ) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`]. /// Conver this `HandlerCallWithExtractors` into [`Handler`].
@ -51,7 +51,7 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// Router, /// Router,
/// async_trait, /// async_trait,
/// routing::get, /// routing::get,
/// extract::FromRequest, /// extract::FromRequestParts,
/// }; /// };
/// ///
/// // handlers for varying levels of access /// // handlers for varying levels of access
@ -71,14 +71,13 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// struct AdminPermissions {} /// struct AdminPermissions {}
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for AdminPermissions /// impl<S> FromRequestParts<S> for AdminPermissions
/// where /// where
/// B: Send,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// // check for admin permissions... /// // check for admin permissions...
/// # type Rejection = (); /// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
/// # todo!() /// # todo!()
/// # } /// # }
/// } /// }
@ -86,14 +85,13 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// struct User {} /// struct User {}
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for User /// impl<S> FromRequestParts<S> for User
/// where /// where
/// B: Send,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// // check for a logged in user... /// // check for a logged in user...
/// # type Rejection = (); /// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
/// # todo!() /// # todo!()
/// # } /// # }
/// } /// }
@ -134,8 +132,8 @@ macro_rules! impl_handler_call_with {
fn call( fn call(
self, self,
_state: Arc<S>,
($($ty,)*): ($($ty,)*), ($($ty,)*): ($($ty,)*),
_state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future { ) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response) self($($ty,)*).map(IntoResponse::into_response)
} }
@ -180,11 +178,10 @@ where
{ {
type Future = BoxFuture<'static, Response>; type Future = BoxFuture<'static, Response>;
fn call(self, state: Arc<S>, req: http::Request<B>) -> Self::Future { fn call(self, req: http::Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); match T::from_request(req, &state).await {
match req.extract::<T>().await { Ok(t) => self.handler.call(t, state).await,
Ok(t) => self.handler.call(state, t).await,
Err(rejection) => rejection.into_response(), Err(rejection) => rejection.into_response(),
} }
}) })

View file

@ -1,13 +1,12 @@
use super::HandlerCallWithExtractors; use super::HandlerCallWithExtractors;
use crate::either::Either; use crate::either::Either;
use axum::{ use axum::{
extract::{FromRequest, RequestParts}, extract::{FromRequest, FromRequestParts},
handler::Handler, handler::Handler,
http::Request, http::Request,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map}; use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use http::StatusCode;
use std::{future::Future, marker::PhantomData, sync::Arc}; use std::{future::Future, marker::PhantomData, sync::Arc};
/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another /// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
@ -37,30 +36,30 @@ where
fn call( fn call(
self, self,
state: Arc<S>,
extractors: Either<Lt, Rt>, extractors: Either<Lt, Rt>,
state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future { ) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors { match extractors {
Either::E1(lt) => self Either::E1(lt) => self
.lhs .lhs
.call(state, lt) .call(lt, state)
.map(IntoResponse::into_response as _) .map(IntoResponse::into_response as _)
.left_future(), .left_future(),
Either::E2(rt) => self Either::E2(rt) => self
.rhs .rhs
.call(state, rt) .call(rt, state)
.map(IntoResponse::into_response as _) .map(IntoResponse::into_response as _)
.right_future(), .right_future(),
} }
} }
} }
impl<S, B, L, R, Lt, Rt> Handler<(Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B> impl<S, B, L, R, Lt, Rt, M> Handler<(M, Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B>
where where
L: HandlerCallWithExtractors<Lt, S, B> + Clone + Send + 'static, L: HandlerCallWithExtractors<Lt, S, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, S, B> + Clone + Send + 'static, R: HandlerCallWithExtractors<Rt, S, B> + Clone + Send + 'static,
Lt: FromRequest<S, B> + Send + 'static, Lt: FromRequestParts<S> + Send + 'static,
Rt: FromRequest<S, B> + Send + 'static, Rt: FromRequest<S, B, M> + Send + 'static,
Lt::Rejection: Send, Lt::Rejection: Send,
Rt::Rejection: Send, Rt::Rejection: Send,
B: Send + 'static, B: Send + 'static,
@ -69,19 +68,20 @@ where
// this puts `futures_util` in our public API but thats fine in axum-extra // this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>; type Future = BoxFuture<'static, Response>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future { fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); let (mut parts, body) = req.into_parts();
if let Ok(lt) = req.extract::<Lt>().await { if let Ok(lt) = Lt::from_request_parts(&mut parts, &state).await {
return self.lhs.call(state, lt).await; return self.lhs.call(lt, state).await;
} }
if let Ok(rt) = req.extract::<Rt>().await { let req = Request::from_parts(parts, body);
return self.rhs.call(state, rt).await;
}
StatusCode::NOT_FOUND.into_response() match Rt::from_request(req, &state).await {
Ok(rt) => self.rhs.call(rt, state).await,
Err(rejection) => rejection.into_response(),
}
}) })
} }
} }

View file

@ -3,15 +3,17 @@
use axum::{ use axum::{
async_trait, async_trait,
body::{HttpBody, StreamBody}, body::{HttpBody, StreamBody},
extract::{rejection::BodyAlreadyExtracted, FromRequest, RequestParts}, extract::FromRequest,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError, BoxError,
}; };
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt}; use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt};
use http::Request;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::{ use std::{
convert::Infallible,
io::{self, Write}, io::{self, Write},
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
@ -106,14 +108,14 @@ where
T: DeserializeOwned, T: DeserializeOwned,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
// `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
// so we can call `AsyncRead::lines` and then convert it back to a `Stream` // so we can call `AsyncRead::lines` and then convert it back to a `Stream`
let body = BodyStream {
let body = req.take_body().ok_or_else(BodyAlreadyExtracted::default)?; body: req.into_body(),
let body = BodyStream { body }; };
let stream = body let stream = body
.map_ok(Into::into) .map_ok(Into::into)

View file

@ -3,12 +3,12 @@
use axum::{ use axum::{
async_trait, async_trait,
body::{Bytes, HttpBody}, body::{Bytes, HttpBody},
extract::{rejection::BytesRejection, FromRequest, RequestParts}, extract::{rejection::BytesRejection, FromRequest},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
BoxError, BoxError,
}; };
use bytes::BytesMut; use bytes::BytesMut;
use http::StatusCode; use http::{Request, StatusCode};
use prost::Message; use prost::Message;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -100,15 +100,15 @@ pub struct ProtoBuf<T>(pub T);
impl<T, S, B> FromRequest<S, B> for ProtoBuf<T> impl<T, S, B> FromRequest<S, B> for ProtoBuf<T>
where where
T: Message + Default, T: Message + Default,
B: HttpBody + Send, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ProtoBufRejection; type Rejection = ProtoBufRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let mut bytes = Bytes::from_request(req).await?; let mut bytes = Bytes::from_request(req, state).await?;
match T::decode(&mut bytes) { match T::decode(&mut bytes) {
Ok(value) => Ok(ProtoBuf(value)), Ok(value) => Ok(ProtoBuf(value)),

View file

@ -24,7 +24,7 @@ pub use self::resource::Resource;
pub use axum_macros::TypedPath; pub use axum_macros::TypedPath;
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
pub use self::typed::{FirstElementIs, TypedPath}; pub use self::typed::{SecondElementIs, TypedPath};
#[cfg(feature = "spa")] #[cfg(feature = "spa")]
pub use self::spa::SpaRouter; pub use self::spa::SpaRouter;
@ -41,7 +41,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_get<H, T, P>(self, handler: H) -> Self fn typed_get<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `DELETE` route to the router. /// Add a typed `DELETE` route to the router.
@ -54,7 +54,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_delete<H, T, P>(self, handler: H) -> Self fn typed_delete<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `HEAD` route to the router. /// Add a typed `HEAD` route to the router.
@ -67,7 +67,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_head<H, T, P>(self, handler: H) -> Self fn typed_head<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `OPTIONS` route to the router. /// Add a typed `OPTIONS` route to the router.
@ -80,7 +80,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_options<H, T, P>(self, handler: H) -> Self fn typed_options<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `PATCH` route to the router. /// Add a typed `PATCH` route to the router.
@ -93,7 +93,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_patch<H, T, P>(self, handler: H) -> Self fn typed_patch<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `POST` route to the router. /// Add a typed `POST` route to the router.
@ -106,7 +106,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_post<H, T, P>(self, handler: H) -> Self fn typed_post<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `PUT` route to the router. /// Add a typed `PUT` route to the router.
@ -119,7 +119,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_put<H, T, P>(self, handler: H) -> Self fn typed_put<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add a typed `TRACE` route to the router. /// Add a typed `TRACE` route to the router.
@ -132,7 +132,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_trace<H, T, P>(self, handler: H) -> Self fn typed_trace<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
/// Add another route to the router with an additional "trailing slash redirect" route. /// Add another route to the router with an additional "trailing slash redirect" route.
@ -184,7 +184,7 @@ where
fn typed_get<H, T, P>(self, handler: H) -> Self fn typed_get<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::get(handler)) self.route(P::PATH, axum::routing::get(handler))
@ -194,7 +194,7 @@ where
fn typed_delete<H, T, P>(self, handler: H) -> Self fn typed_delete<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::delete(handler)) self.route(P::PATH, axum::routing::delete(handler))
@ -204,7 +204,7 @@ where
fn typed_head<H, T, P>(self, handler: H) -> Self fn typed_head<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::head(handler)) self.route(P::PATH, axum::routing::head(handler))
@ -214,7 +214,7 @@ where
fn typed_options<H, T, P>(self, handler: H) -> Self fn typed_options<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::options(handler)) self.route(P::PATH, axum::routing::options(handler))
@ -224,7 +224,7 @@ where
fn typed_patch<H, T, P>(self, handler: H) -> Self fn typed_patch<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::patch(handler)) self.route(P::PATH, axum::routing::patch(handler))
@ -234,7 +234,7 @@ where
fn typed_post<H, T, P>(self, handler: H) -> Self fn typed_post<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::post(handler)) self.route(P::PATH, axum::routing::post(handler))
@ -244,7 +244,7 @@ where
fn typed_put<H, T, P>(self, handler: H) -> Self fn typed_put<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::put(handler)) self.route(P::PATH, axum::routing::put(handler))
@ -254,7 +254,7 @@ where
fn typed_trace<H, T, P>(self, handler: H) -> Self fn typed_trace<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, S, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: SecondElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::trace(handler)) self.route(P::PATH, axum::routing::trace(handler))

View file

@ -231,10 +231,10 @@ pub trait TypedPath: std::fmt::Display {
} }
} }
/// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a /// Utility trait used with [`RouterExt`] to ensure the second element of a tuple type is a
/// given type. /// given type.
/// ///
/// If you see it in type errors its most likely because the first argument to your handler doesn't /// If you see it in type errors its most likely because the second argument to your handler doesn't
/// implement [`TypedPath`]. /// implement [`TypedPath`].
/// ///
/// You normally shouldn't have to use this trait directly. /// You normally shouldn't have to use this trait directly.
@ -242,56 +242,56 @@ pub trait TypedPath: std::fmt::Display {
/// It is sealed such that it cannot be implemented outside this crate. /// It is sealed such that it cannot be implemented outside this crate.
/// ///
/// [`RouterExt`]: super::RouterExt /// [`RouterExt`]: super::RouterExt
pub trait FirstElementIs<P>: Sealed {} pub trait SecondElementIs<P>: Sealed {}
macro_rules! impl_first_element_is { macro_rules! impl_second_element_is {
( $($ty:ident),* $(,)? ) => { ( $($ty:ident),* $(,)? ) => {
impl<P, $($ty,)*> FirstElementIs<P> for (P, $($ty,)*) impl<M, P, $($ty,)*> SecondElementIs<P> for (M, P, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
impl<P, $($ty,)*> Sealed for (P, $($ty,)*) impl<M, P, $($ty,)*> Sealed for (M, P, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
impl<P, $($ty,)*> FirstElementIs<P> for (Option<P>, $($ty,)*) impl<M, P, $($ty,)*> SecondElementIs<P> for (M, Option<P>, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
impl<P, $($ty,)*> Sealed for (Option<P>, $($ty,)*) impl<M, P, $($ty,)*> Sealed for (M, Option<P>, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
impl<P, E, $($ty,)*> FirstElementIs<P> for (Result<P, E>, $($ty,)*) impl<M, P, E, $($ty,)*> SecondElementIs<P> for (M, Result<P, E>, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
impl<P, E, $($ty,)*> Sealed for (Result<P, E>, $($ty,)*) impl<M, P, E, $($ty,)*> Sealed for (M, Result<P, E>, $($ty,)*)
where where
P: TypedPath P: TypedPath
{} {}
}; };
} }
impl_first_element_is!(); impl_second_element_is!();
impl_first_element_is!(T1); impl_second_element_is!(T1);
impl_first_element_is!(T1, T2); impl_second_element_is!(T1, T2);
impl_first_element_is!(T1, T2, T3); impl_second_element_is!(T1, T2, T3);
impl_first_element_is!(T1, T2, T3, T4); impl_second_element_is!(T1, T2, T3, T4);
impl_first_element_is!(T1, T2, T3, T4, T5); impl_second_element_is!(T1, T2, T3, T4, T5);
impl_first_element_is!(T1, T2, T3, T4, T5, T6); impl_second_element_is!(T1, T2, T3, T4, T5, T6);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);

View file

@ -10,9 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **change:** axum-macro's MSRV is now 1.60 ([#1239]) - **change:** axum-macro's MSRV is now 1.60 ([#1239])
- **added:** Support using a different rejection for `#[derive(FromRequest)]` - **added:** Support using a different rejection for `#[derive(FromRequest)]`
with `#[from_request(rejection(MyRejection))]` ([#1256]) with `#[from_request(rejection(MyRejection))]` ([#1256])
- **breaking:** `#[derive(FromRequest)]` will no longer generate a rejection
enum but instead generate `type Rejection = axum::response::Response`. Use the
new `#[from_request(rejection(MyRejection))]` attribute to change this.
The `rejection_derive` attribute has also been removed ([#1272])
[#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1256]: https://github.com/tokio-rs/axum/pull/1256 [#1256]: https://github.com/tokio-rs/axum/pull/1256
[#1272]: https://github.com/tokio-rs/axum/pull/1272
# 0.2.3 (27. June, 2022) # 0.2.3 (27. June, 2022)

View file

@ -1,14 +1,11 @@
use std::collections::HashSet;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
use std::collections::HashSet;
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream { pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
let check_extractor_count = check_extractor_count(&item_fn); let check_extractor_count = check_extractor_count(&item_fn);
let check_request_last_extractor = check_request_last_extractor(&item_fn);
let check_path_extractor = check_path_extractor(&item_fn); let check_path_extractor = check_path_extractor(&item_fn);
let check_multiple_body_extractors = check_multiple_body_extractors(&item_fn);
let check_output_impls_into_response = check_output_impls_into_response(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn);
// If the function is generic, we can't reliably check its inputs or whether the future it // If the function is generic, we can't reliably check its inputs or whether the future it
@ -39,9 +36,7 @@ pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
quote! { quote! {
#item_fn #item_fn
#check_extractor_count #check_extractor_count
#check_request_last_extractor
#check_path_extractor #check_path_extractor
#check_multiple_body_extractors
#check_output_impls_into_response #check_output_impls_into_response
#check_inputs_and_future_send #check_inputs_and_future_send
} }
@ -135,22 +130,6 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator<Item = (usize, &syn::FnAr
}) })
} }
fn check_request_last_extractor(item_fn: &ItemFn) -> Option<TokenStream> {
let request_extractor_ident =
extractor_idents(item_fn).find(|(_, _, ident)| *ident == "Request");
if let Some((idx, fn_arg, _)) = request_extractor_ident {
if idx != item_fn.sig.inputs.len() - 1 {
return Some(
syn::Error::new_spanned(fn_arg, "`Request` extractor should always be last")
.to_compile_error(),
);
}
}
None
}
fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { fn check_path_extractor(item_fn: &ItemFn) -> TokenStream {
let path_extractors = extractor_idents(item_fn) let path_extractors = extractor_idents(item_fn)
.filter(|(_, _, ident)| *ident == "Path") .filter(|(_, _, ident)| *ident == "Path")
@ -174,30 +153,14 @@ fn check_path_extractor(item_fn: &ItemFn) -> TokenStream {
} }
} }
fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream { fn is_self_pat_type(typed: &syn::PatType) -> bool {
let body_extractors = extractor_idents(item_fn) let ident = if let syn::Pat::Ident(ident) = &*typed.pat {
.filter(|(_, _, ident)| { &ident.ident
*ident == "String"
|| *ident == "Bytes"
|| *ident == "Json"
|| *ident == "RawBody"
|| *ident == "BodyStream"
|| *ident == "Multipart"
|| *ident == "Request"
})
.collect::<Vec<_>>();
if body_extractors.len() > 1 {
body_extractors
.into_iter()
.map(|(_, arg, _)| {
syn::Error::new_spanned(arg, "Only one body extractor can be applied")
.to_compile_error()
})
.collect()
} else { } else {
quote! {} return false;
} };
ident == "self"
} }
fn check_inputs_impls_from_request( fn check_inputs_impls_from_request(
@ -205,6 +168,11 @@ fn check_inputs_impls_from_request(
body_ty: &Type, body_ty: &Type,
state_ty: Type, state_ty: Type,
) -> TokenStream { ) -> TokenStream {
let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg {
FnArg::Receiver(_) => true,
FnArg::Typed(typed) => is_self_pat_type(typed),
});
item_fn item_fn
.sig .sig
.inputs .inputs
@ -227,21 +195,53 @@ fn check_inputs_impls_from_request(
FnArg::Typed(typed) => { FnArg::Typed(typed) => {
let ty = &typed.ty; let ty = &typed.ty;
let span = ty.span(); let span = ty.span();
if is_self_pat_type(typed) {
(span, syn::parse_quote!(Self))
} else {
(span, ty.clone()) (span, ty.clone())
} }
}
}; };
let name = format_ident!( let check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request", "__axum_macros_check_{}_{}_from_request_check",
item_fn.sig.ident, item_fn.sig.ident,
idx idx,
span = span,
); );
let call_check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_call_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn_body = if takes_self {
quote_spanned! {span=>
Self::#check_fn();
}
} else {
quote_spanned! {span=>
#check_fn();
}
};
quote_spanned! {span=> quote_spanned! {span=>
#[allow(warnings)] #[allow(warnings)]
fn #name() fn #check_fn<M>()
where where
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send, #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send,
{} {}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
fn #call_check_fn()
{
#call_check_fn_body
}
} }
}) })
.collect::<TokenStream>() .collect::<TokenStream>()
@ -380,11 +380,11 @@ fn check_future_send(item_fn: &ItemFn) -> TokenStream {
} }
fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> { fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
let takes_self = item_fn let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg {
.sig FnArg::Receiver(_) => true,
.inputs FnArg::Typed(typed) => is_self_pat_type(typed),
.iter() });
.any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
if takes_self { if takes_self {
return Some(quote! { Self:: }); return Some(quote! { Self:: });
} }

View file

@ -1,10 +1,8 @@
use self::attr::{ use self::attr::{
parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr, parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr,
RejectionDeriveOptOuts,
}; };
use heck::ToUpperCamelCase;
use proc_macro2::{Span, TokenStream}; use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned}; use quote::{quote, quote_spanned};
use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token}; use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token};
mod attr; mod attr;
@ -18,7 +16,7 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
generics, generics,
fields, fields,
semi_token: _, semi_token: _,
vis, vis: _,
struct_token: _, struct_token: _,
} = item; } = item;
@ -34,32 +32,15 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
generic_ident, generic_ident,
) )
} }
FromRequestContainerAttr::RejectionDerive(_, opt_outs) => {
error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field(ident, fields, vis, opt_outs, None)
}
FromRequestContainerAttr::Rejection(rejection) => { FromRequestContainerAttr::Rejection(rejection) => {
error_on_generic_ident(generic_ident)?; error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field( impl_struct_by_extracting_each_field(ident, fields, Some(rejection))
ident,
fields,
vis,
RejectionDeriveOptOuts::default(),
Some(rejection),
)
} }
FromRequestContainerAttr::None => { FromRequestContainerAttr::None => {
error_on_generic_ident(generic_ident)?; error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field( impl_struct_by_extracting_each_field(ident, fields, None)
ident,
fields,
vis,
RejectionDeriveOptOuts::default(),
None,
)
} }
} }
} }
@ -88,12 +69,6 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
FromRequestContainerAttr::Via { path, rejection } => { FromRequestContainerAttr::Via { path, rejection } => {
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection) impl_enum_by_extracting_all_at_once(ident, variants, path, rejection)
} }
FromRequestContainerAttr::RejectionDerive(rejection_derive, _) => {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use `rejection_derive` on enums",
))
}
FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned( FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned(
rejection, rejection,
"cannot use `rejection` without `via`", "cannot use `rejection` without `via`",
@ -197,22 +172,16 @@ fn error_on_generic_ident(generic_ident: Option<Ident>) -> syn::Result<()> {
fn impl_struct_by_extracting_each_field( fn impl_struct_by_extracting_each_field(
ident: syn::Ident, ident: syn::Ident,
fields: syn::Fields, fields: syn::Fields,
vis: syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
rejection: Option<syn::Path>, rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let extract_fields = extract_fields(&fields, &rejection)?; let extract_fields = extract_fields(&fields, &rejection)?;
let (rejection_ident, rejection) = if let Some(rejection) = rejection { let rejection_ident = if let Some(rejection) = rejection {
let rejection_ident = syn::parse_quote!(#rejection); quote!(#rejection)
(rejection_ident, None)
} else if has_no_fields(&fields) { } else if has_no_fields(&fields) {
(syn::parse_quote!(::std::convert::Infallible), None) quote!(::std::convert::Infallible)
} else { } else {
let rejection_ident = rejection_ident(&ident); quote!(::axum::response::Response)
let rejection =
extract_each_field_rejection(&ident, &fields, &vis, rejection_derive_opt_outs)?;
(rejection_ident, Some(rejection))
}; };
Ok(quote! { Ok(quote! {
@ -228,15 +197,14 @@ fn impl_struct_by_extracting_each_field(
type Rejection = #rejection_ident; type Rejection = #rejection_ident;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>, mut req: axum::http::Request<B>,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self { ::std::result::Result::Ok(Self {
#(#extract_fields)* #(#extract_fields)*
}) })
} }
} }
#rejection
}) })
} }
@ -248,11 +216,6 @@ fn has_no_fields(fields: &syn::Fields) -> bool {
} }
} }
fn rejection_ident(ident: &syn::Ident) -> syn::Type {
let ident = format_ident!("{}Rejection", ident);
syn::parse_quote!(#ident)
}
fn extract_fields( fn extract_fields(
fields: &syn::Fields, fields: &syn::Fields,
rejection: &Option<syn::Path>, rejection: &Option<syn::Path>,
@ -261,6 +224,8 @@ fn extract_fields(
.iter() .iter()
.enumerate() .enumerate()
.map(|(index, field)| { .map(|(index, field)| {
let is_last = fields.len() - 1 == index;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = if let Some(ident) = &field.ident { let member = if let Some(ident) = &field.ident {
@ -286,40 +251,79 @@ fn extract_fields(
} }
}; };
let rejection_variant_name = rejection_variant_name(field)?;
if peel_option(&field.ty).is_some() { if peel_option(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=> Ok(quote_spanned! {ty_span=>
#member: { #member: {
::axum::extract::FromRequest::from_request(req) ::axum::extract::FromRequest::from_request(req, state)
.await .await
.ok() .ok()
.map(#into_inner) .map(#into_inner)
}, },
}) })
} else if peel_result_ok(&field.ty).is_some() { } else {
Ok(quote_spanned! {ty_span=> Ok(quote_spanned! {ty_span=>
#member: { #member: {
::axum::extract::FromRequest::from_request(req) let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.ok()
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
} else if peel_result_ok(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
.await .await
.map(#into_inner) .map(#into_inner)
}, },
}) })
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
} else { } else {
let map_err = if let Some(rejection) = rejection { let map_err = if let Some(rejection) = rejection {
quote! { <#rejection as ::std::convert::From<_>>::from } quote! { <#rejection as ::std::convert::From<_>>::from }
} else { } else {
quote! { Self::Rejection::#rejection_variant_name } quote! { ::axum::response::IntoResponse::into_response }
}; };
if is_last {
Ok(quote_spanned! {ty_span=> Ok(quote_spanned! {ty_span=>
#member: { #member: {
::axum::extract::FromRequest::from_request(req) ::axum::extract::FromRequest::from_request(req, state)
.await .await
.map(#into_inner) .map(#into_inner)
.map_err(#map_err)? .map_err(#map_err)?
}, },
}) })
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner)
.map_err(#map_err)?;
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
} }
}) })
.collect() .collect()
@ -387,199 +391,6 @@ fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> {
} }
} }
fn extract_each_field_rejection(
ident: &syn::Ident,
fields: &syn::Fields,
vis: &syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
) -> syn::Result<TokenStream> {
let rejection_ident = rejection_ident(ident);
let variants = fields
.iter()
.map(|field| {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let field_ty = &field.ty;
let ty_span = field_ty.span();
let variant_name = rejection_variant_name(field)?;
let extractor_ty = if let Some((_, path)) = via {
if let Some(inner) = peel_option(field_ty) {
quote_spanned! {ty_span=>
::std::option::Option<#path<#inner>>
}
} else if let Some(inner) = peel_result_ok(field_ty) {
quote_spanned! {ty_span=>
::std::result::Result<#path<#inner>, TypedHeaderRejection>
}
} else {
quote_spanned! {ty_span=> #path<#field_ty> }
}
} else {
quote_spanned! {ty_span=> #field_ty }
};
Ok(quote_spanned! {ty_span=>
#[allow(non_camel_case_types)]
#variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection),
})
})
.collect::<syn::Result<Vec<_>>>()?;
let impl_into_response = {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => inner.into_response(),
})
})
.collect::<syn::Result<Vec<_>>>()?;
quote! {
#[automatically_derived]
impl ::axum::response::IntoResponse for #rejection_ident {
fn into_response(self) -> ::axum::response::Response {
match self {
#(#arms)*
}
}
}
}
};
let impl_display = if rejection_derive_opt_outs.derive_display() {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => inner.fmt(f),
})
})
.collect::<syn::Result<Vec<_>>>()?;
Some(quote! {
#[automatically_derived]
impl ::std::fmt::Display for #rejection_ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
match self {
#(#arms)*
}
}
}
})
} else {
None
};
let impl_error = if rejection_derive_opt_outs.derive_error() {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => Some(inner),
})
})
.collect::<syn::Result<Vec<_>>>()?;
Some(quote! {
#[automatically_derived]
impl ::std::error::Error for #rejection_ident {
fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
match self {
#(#arms)*
}
}
}
})
} else {
None
};
let impl_debug = rejection_derive_opt_outs.derive_debug().then(|| {
quote! { #[derive(Debug)] }
});
Ok(quote! {
#impl_debug
#vis enum #rejection_ident {
#(#variants)*
}
#impl_into_response
#impl_display
#impl_error
})
}
fn rejection_variant_name(field: &syn::Field) -> syn::Result<syn::Ident> {
fn rejection_variant_name_for_type(out: &mut String, ty: &syn::Type) -> syn::Result<()> {
if let syn::Type::Path(type_path) = ty {
let segment = type_path
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(ty, "Empty type path"))?;
out.push_str(&segment.ident.to_string());
match &segment.arguments {
syn::PathArguments::AngleBracketed(args) => {
let ty = if args.args.len() == 1 {
args.args.last().unwrap()
} else if args.args.len() == 2 {
if segment.ident == "Result" {
args.args.first().unwrap()
} else {
return Err(syn::Error::new_spanned(
segment,
"Only `Result<T, E>` is supported with two generics type paramters",
));
}
} else {
return Err(syn::Error::new_spanned(
&args.args,
"Expected exactly one or two type paramters",
));
};
if let syn::GenericArgument::Type(ty) = ty {
rejection_variant_name_for_type(out, ty)
} else {
Err(syn::Error::new_spanned(ty, "Expected type path"))
}
}
syn::PathArguments::Parenthesized(args) => {
Err(syn::Error::new_spanned(args, "Unsupported"))
}
syn::PathArguments::None => Ok(()),
}
} else {
Err(syn::Error::new_spanned(ty, "Expected type path"))
}
}
if let Some(ident) = &field.ident {
Ok(format_ident!("{}", ident.to_string().to_upper_camel_case()))
} else {
let mut out = String::new();
rejection_variant_name_for_type(&mut out, &field.ty)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
if let Some((_, path)) = via {
let via_ident = &path.segments.last().unwrap().ident;
Ok(format_ident!("{}{}", via_ident, out))
} else {
Ok(format_ident!("{}", out))
}
}
}
fn impl_struct_by_extracting_all_at_once( fn impl_struct_by_extracting_all_at_once(
ident: syn::Ident, ident: syn::Ident,
fields: syn::Fields, fields: syn::Fields,
@ -606,12 +417,16 @@ fn impl_struct_by_extracting_all_at_once(
let path_span = path.span(); let path_span = path.span();
let associated_rejection_type = if let Some(rejection) = &rejection { let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
quote! { #rejection } let rejection = quote! { #rejection };
let map_err = quote! { ::std::convert::From::from };
(rejection, map_err)
} else { } else {
quote! { let rejection = quote! {
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection ::axum::response::Response
} };
let map_err = quote! { ::axum::response::IntoResponse::into_response };
(rejection, map_err)
}; };
let rejection_bound = rejection.as_ref().map(|rejection| { let rejection_bound = rejection.as_ref().map(|rejection| {
@ -658,18 +473,19 @@ fn impl_struct_by_extracting_all_at_once(
where where
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>, #path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound #rejection_bound
B: ::std::marker::Send, B: ::std::marker::Send + 'static,
S: ::std::marker::Send + ::std::marker::Sync, S: ::std::marker::Send + ::std::marker::Sync,
{ {
type Rejection = #associated_rejection_type; type Rejection = #associated_rejection_type;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>, req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<S, B>::from_request(req) ::axum::extract::FromRequest::from_request(req, state)
.await .await
.map(|#path(value)| #value_to_self) .map(|#path(value)| #value_to_self)
.map_err(::std::convert::From::from) .map_err(#map_err)
} }
} }
}) })
@ -707,12 +523,16 @@ fn impl_enum_by_extracting_all_at_once(
} }
} }
let associated_rejection_type = if let Some(rejection) = rejection { let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
quote! { #rejection } let rejection = quote! { #rejection };
let map_err = quote! { ::std::convert::From::from };
(rejection, map_err)
} else { } else {
quote! { let rejection = quote! {
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection ::axum::response::Response
} };
let map_err = quote! { ::axum::response::IntoResponse::into_response };
(rejection, map_err)
}; };
let path_span = path.span(); let path_span = path.span();
@ -730,12 +550,13 @@ fn impl_enum_by_extracting_all_at_once(
type Rejection = #associated_rejection_type; type Rejection = #associated_rejection_type;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>, req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<S, B>::from_request(req) ::axum::extract::FromRequest::from_request(req, state)
.await .await
.map(|#path(inner)| inner) .map(|#path(inner)| inner)
.map_err(::std::convert::From::from) .map_err(#map_err)
} }
} }
}) })

View file

@ -16,13 +16,11 @@ pub(crate) enum FromRequestContainerAttr {
rejection: Option<syn::Path>, rejection: Option<syn::Path>,
}, },
Rejection(syn::Path), Rejection(syn::Path),
RejectionDerive(kw::rejection_derive, RejectionDeriveOptOuts),
None, None,
} }
pub(crate) mod kw { pub(crate) mod kw {
syn::custom_keyword!(via); syn::custom_keyword!(via);
syn::custom_keyword!(rejection_derive);
syn::custom_keyword!(rejection); syn::custom_keyword!(rejection);
syn::custom_keyword!(Display); syn::custom_keyword!(Display);
syn::custom_keyword!(Debug); syn::custom_keyword!(Debug);
@ -55,7 +53,6 @@ pub(crate) fn parse_container_attrs(
let attrs = parse_attrs::<ContainerAttr>(attrs)?; let attrs = parse_attrs::<ContainerAttr>(attrs)?;
let mut out_via = None; let mut out_via = None;
let mut out_rejection_derive = None;
let mut out_rejection = None; let mut out_rejection = None;
// we track the index of the attribute to know which comes last // we track the index of the attribute to know which comes last
@ -69,16 +66,6 @@ pub(crate) fn parse_container_attrs(
out_via = Some((idx, via, path)); out_via = Some((idx, via, path));
} }
} }
ContainerAttr::RejectionDerive {
rejection_derive,
opt_outs,
} => {
if out_rejection_derive.is_some() {
return Err(double_attr_error("rejection_derive", rejection_derive));
} else {
out_rejection_derive = Some((idx, rejection_derive, opt_outs));
}
}
ContainerAttr::Rejection { rejection, path } => { ContainerAttr::Rejection { rejection, path } => {
if out_rejection.is_some() { if out_rejection.is_some() {
return Err(double_attr_error("rejection", rejection)); return Err(double_attr_error("rejection", rejection));
@ -89,55 +76,20 @@ pub(crate) fn parse_container_attrs(
} }
} }
match (out_via, out_rejection_derive, out_rejection) { match (out_via, out_rejection) {
(Some((via_idx, via, _)), Some((rejection_derive_idx, rejection_derive, _)), _) => { (Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via {
if via_idx > rejection_derive_idx {
Err(syn::Error::new_spanned(
via,
"cannot use both `rejection_derive` and `via`",
))
} else {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use both `via` and `rejection_derive`",
))
}
}
(
_,
Some((rejection_derive_idx, rejection_derive, _)),
Some((rejection_idx, rejection, _)),
) => {
if rejection_idx > rejection_derive_idx {
Err(syn::Error::new_spanned(
rejection,
"cannot use both `rejection_derive` and `rejection`",
))
} else {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use both `rejection` and `rejection_derive`",
))
}
}
(Some((_, _, path)), None, None) => Ok(FromRequestContainerAttr::Via {
path, path,
rejection: None, rejection: None,
}), }),
(Some((_, _, path)), None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via {
(Some((_, _, path)), Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via {
path, path,
rejection: Some(rejection), rejection: Some(rejection),
}), }),
(None, Some((_, rejection_derive, opt_outs)), _) => Ok( (None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)),
FromRequestContainerAttr::RejectionDerive(rejection_derive, opt_outs),
),
(None, None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)), (None, None) => Ok(FromRequestContainerAttr::None),
(None, None, None) => Ok(FromRequestContainerAttr::None),
} }
} }
@ -172,10 +124,6 @@ enum ContainerAttr {
rejection: kw::rejection, rejection: kw::rejection,
path: syn::Path, path: syn::Path,
}, },
RejectionDerive {
rejection_derive: kw::rejection_derive,
opt_outs: RejectionDeriveOptOuts,
},
} }
impl Parse for ContainerAttr { impl Parse for ContainerAttr {
@ -186,14 +134,6 @@ impl Parse for ContainerAttr {
let content; let content;
syn::parenthesized!(content in input); syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path }) content.parse().map(|path| Self::Via { via, path })
} else if lh.peek(kw::rejection_derive) {
let rejection_derive = input.parse::<kw::rejection_derive>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|opt_outs| Self::RejectionDerive {
rejection_derive,
opt_outs,
})
} else if lh.peek(kw::rejection) { } else if lh.peek(kw::rejection) {
let rejection = input.parse::<kw::rejection>()?; let rejection = input.parse::<kw::rejection>()?;
let content; let content;
@ -224,82 +164,3 @@ impl Parse for FieldAttr {
} }
} }
} }
#[derive(Default)]
pub(crate) struct RejectionDeriveOptOuts {
debug: Option<kw::Debug>,
display: Option<kw::Display>,
error: Option<kw::Error>,
}
impl RejectionDeriveOptOuts {
pub(crate) fn derive_debug(&self) -> bool {
self.debug.is_none()
}
pub(crate) fn derive_display(&self) -> bool {
self.display.is_none()
}
pub(crate) fn derive_error(&self) -> bool {
self.error.is_none()
}
}
impl Parse for RejectionDeriveOptOuts {
fn parse(input: ParseStream) -> syn::Result<Self> {
fn parse_opt_out<T>(out: &mut Option<T>, ident: &str, input: ParseStream) -> syn::Result<()>
where
T: Parse,
{
if out.is_some() {
Err(input.error(format!("`{}` opt out specified more than once", ident)))
} else {
*out = Some(input.parse()?);
Ok(())
}
}
let mut debug = None::<kw::Debug>;
let mut display = None::<kw::Display>;
let mut error = None::<kw::Error>;
while !input.is_empty() {
input.parse::<Token![!]>()?;
let lh = input.lookahead1();
if lh.peek(kw::Debug) {
parse_opt_out(&mut debug, "Debug", input)?;
} else if lh.peek(kw::Display) {
parse_opt_out(&mut display, "Display", input)?;
} else if lh.peek(kw::Error) {
parse_opt_out(&mut error, "Error", input)?;
} else {
return Err(lh.error());
}
input.parse::<Token![,]>().ok();
}
if error.is_none() {
match (debug, display) {
(Some(debug), Some(_)) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`"));
}
(Some(debug), None) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`"));
}
(None, Some(display)) => {
return Err(syn::Error::new_spanned(display, "opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`"));
}
(None, None) => {}
}
}
Ok(Self {
debug,
display,
error,
})
}
}

View file

@ -86,6 +86,20 @@ mod typed_path;
/// ///
/// This requires that each field is an extractor (i.e. implements [`FromRequest`]). /// This requires that each field is an extractor (i.e. implements [`FromRequest`]).
/// ///
/// ```compile_fail
/// use axum_macros::FromRequest;
/// use axum::body::Bytes;
///
/// #[derive(FromRequest)]
/// struct MyExtractor {
/// // only the last field can implement `FromRequest`
/// // other fields must only implement `FromRequestParts`
/// bytes: Bytes,
/// string: String,
/// }
/// ```
/// Note that only the last field can consume the request body. Therefore this doesn't compile:
///
/// ## Extracting via another extractor /// ## Extracting via another extractor
/// ///
/// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the /// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the
@ -157,95 +171,15 @@ mod typed_path;
/// ///
/// ## The rejection /// ## The rejection
/// ///
/// A rejection enum is also generated. It has a variant for each field: /// By default [`axum::response::Response`] will be used as the rejection. You can also use your own
/// /// rejection type with `#[from_request(rejection(YourType))]`:
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{Extension, TypedHeader},
/// headers::ContentType,
/// body::Bytes,
/// };
///
/// #[derive(FromRequest)]
/// struct MyExtractor {
/// #[from_request(via(Extension))]
/// state: State,
/// #[from_request(via(TypedHeader))]
/// content_type: ContentType,
/// request_body: Bytes,
/// }
///
/// // also generates
/// //
/// // #[derive(Debug)]
/// // enum MyExtractorRejection {
/// // State(ExtensionRejection),
/// // ContentType(TypedHeaderRejection),
/// // RequestBody(BytesRejection),
/// // }
/// //
/// // impl axum::response::IntoResponse for MyExtractor { ... }
/// //
/// // impl std::fmt::Display for MyExtractor { ... }
/// //
/// // impl std::error::Error for MyExtractor { ... }
///
/// #[derive(Clone)]
/// struct State {
/// // ...
/// }
/// ```
///
/// The rejection's `std::error::Error::source` implementation returns the inner rejection. This
/// can be used to access source errors for example to customize rejection responses. Note this
/// means the inner rejection types must themselves implement `std::error::Error`. All extractors
/// in axum does this.
///
/// You can opt out of this using `#[from_request(rejection_derive(...))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{FromRequest, RequestParts},
/// http::StatusCode,
/// headers::ContentType,
/// body::Bytes,
/// async_trait,
/// };
///
/// #[derive(FromRequest)]
/// #[from_request(rejection_derive(!Display, !Error))]
/// struct MyExtractor {
/// other: OtherExtractor,
/// }
///
/// struct OtherExtractor;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for OtherExtractor
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
/// }
/// ```
///
/// You can also use your own rejection type with `#[from_request(rejection(YourType))]`:
/// ///
/// ``` /// ```
/// use axum_macros::FromRequest; /// use axum_macros::FromRequest;
/// use axum::{ /// use axum::{
/// extract::{ /// extract::{
/// rejection::{ExtensionRejection, StringRejection}, /// rejection::{ExtensionRejection, StringRejection},
/// FromRequest, RequestParts, /// FromRequest,
/// }, /// },
/// Extension, /// Extension,
/// response::{Response, IntoResponse}, /// response::{Response, IntoResponse},
@ -414,6 +348,7 @@ mod typed_path;
/// ``` /// ```
/// ///
/// [`FromRequest`]: https://docs.rs/axum/latest/axum/extract/trait.FromRequest.html /// [`FromRequest`]: https://docs.rs/axum/latest/axum/extract/trait.FromRequest.html
/// [`axum::response::Response`]: https://docs.rs/axum/0.6/axum/response/type.Response.html
/// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html
#[proc_macro_derive(FromRequest, attributes(from_request))] #[proc_macro_derive(FromRequest, attributes(from_request))]
pub fn derive_from_request(item: TokenStream) -> TokenStream { pub fn derive_from_request(item: TokenStream) -> TokenStream {

View file

@ -127,15 +127,17 @@ fn expand_named_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident impl<S> ::axum::extract::FromRequestParts<S> for #ident
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request_parts(
::axum::extract::Path::from_request(req) parts: &mut ::axum::http::request::Parts,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await .await
.map(|path| path.0) .map(|path| path.0)
#map_err_rejection #map_err_rejection
@ -230,15 +232,17 @@ fn expand_unnamed_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident impl<S> ::axum::extract::FromRequestParts<S> for #ident
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request_parts(
::axum::extract::Path::from_request(req) parts: &mut ::axum::http::request::Parts,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await .await
.map(|path| path.0) .map(|path| path.0)
#map_err_rejection #map_err_rejection
@ -312,15 +316,17 @@ fn expand_unit_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident impl<S> ::axum::extract::FromRequestParts<S> for #ident
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request_parts(
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH { parts: &mut ::axum::http::request::Parts,
_state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
if parts.uri.path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self) Ok(Self)
} else { } else {
#create_rejection #create_rejection
@ -390,7 +396,7 @@ enum Segment {
fn path_rejection() -> TokenStream { fn path_rejection() -> TokenStream {
quote! { quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection <::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection
} }
} }

View file

@ -1,17 +1,22 @@
error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23 --> tests/debug_handler/fail/argument_not_extractor.rs:4:23
| |
4 | async fn handler(foo: bool) {} 4 | async fn handler(foo: bool) {}
| ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` | ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool`
| |
= help: the following other types implement trait `FromRequest<S, B>`: = help: the following other types implement trait `FromRequestParts<S>`:
<() as FromRequest<S, B>> <() as FromRequestParts<S>>
<(T1, T2) as FromRequest<S, B>> <(T1, T2) as FromRequestParts<S>>
<(T1, T2, T3) as FromRequest<S, B>> <(T1, T2, T3) as FromRequestParts<S>>
<(T1, T2, T3, T4) as FromRequest<S, B>> <(T1, T2, T3, T4) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5) as FromRequest<S, B>> <(T1, T2, T3, T4, T5) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6) as FromRequest<S, B>> <(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequest<S, B>> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest<S, B>> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 34 others and 26 others
= help: see issue #48214 = note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool`
note: required by a bound in `__axum_macros_check_handler_0_from_request_check`
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23
|
4 | async fn handler(foo: bool) {}
| ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check`

View file

@ -1,6 +1,7 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequest,
http::Request,
}; };
use axum_macros::debug_handler; use axum_macros::debug_handler;
@ -9,12 +10,12 @@ struct A;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_mut.rs:24:22 --> tests/debug_handler/fail/extract_self_mut.rs:25:22
| |
24 | async fn handler(&mut self) {} 25 | async fn handler(&mut self) {}
| ^^^^^^^^^ | ^^^^^^^^^

View file

@ -1,6 +1,7 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequest,
http::Request,
}; };
use axum_macros::debug_handler; use axum_macros::debug_handler;
@ -9,12 +10,12 @@ struct A;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_ref.rs:24:22 --> tests/debug_handler/fail/extract_self_ref.rs:25:22
| |
24 | async fn handler(&self) {} 25 | async fn handler(&self) {}
| ^^^^^ | ^^^^^

View file

@ -1,7 +0,0 @@
use axum_macros::debug_handler;
use axum::body::Bytes;
#[debug_handler]
async fn handler(_: String, _: Bytes) {}
fn main() {}

View file

@ -1,11 +0,0 @@
error: Only one body extractor can be applied
--> tests/debug_handler/fail/multiple_body_extractors.rs:5:18
|
5 | async fn handler(_: String, _: Bytes) {}
| ^^^^^^^^^
error: Only one body extractor can be applied
--> tests/debug_handler/fail/multiple_body_extractors.rs:5:29
|
5 | async fn handler(_: String, _: Bytes) {}
| ^^^^^^^^

View file

@ -1,7 +0,0 @@
use axum::{body::Body, extract::Extension, http::Request};
use axum_macros::debug_handler;
#[debug_handler]
async fn handler(_: Request<Body>, _: Extension<String>) {}
fn main() {}

View file

@ -1,5 +0,0 @@
error: `Request` extractor should always be last
--> tests/debug_handler/fail/request_not_last.rs:5:18
|
5 | async fn handler(_: Request<Body>, _: Extension<String>) {}
| ^^^^^^^^^^^^^^^^

View file

@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied
(Response<()>, T1, T2, R) (Response<()>, T1, T2, R)
(Response<()>, T1, T2, T3, R) (Response<()>, T1, T2, T3, R)
(Response<()>, T1, T2, T3, T4, R) (Response<()>, T1, T2, T3, T4, R)
and 123 others and 122 others
note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check`
--> tests/debug_handler/fail/wrong_return_type.rs:4:23 --> tests/debug_handler/fail/wrong_return_type.rs:4:23
| |

View file

@ -1,8 +1,4 @@
use axum::{ use axum::{async_trait, extract::FromRequest, http::Request, response::IntoResponse};
async_trait,
extract::{FromRequest, RequestParts},
response::IntoResponse,
};
use axum_macros::debug_handler; use axum_macros::debug_handler;
fn main() {} fn main() {}
@ -122,12 +118,12 @@ impl A {
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,6 +1,7 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequest,
http::Request,
}; };
use axum_macros::debug_handler; use axum_macros::debug_handler;
@ -9,12 +10,25 @@ struct A;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Box<A>
where
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }
@ -22,6 +36,9 @@ where
impl A { impl A {
#[debug_handler] #[debug_handler]
async fn handler(self) {} async fn handler(self) {}
#[debug_handler]
async fn handler_with_qualified_self(self: Box<Self>) {}
} }
fn main() {} fn main() {}

View file

@ -1,6 +1,7 @@
use axum_macros::debug_handler; use axum_macros::debug_handler;
use axum::extract::{FromRef, FromRequest, RequestParts}; use axum::extract::{FromRef, FromRequest};
use axum::async_trait; use axum::async_trait;
use axum::http::Request;
#[debug_handler(state = AppState)] #[debug_handler(state = AppState)]
async fn handler(_: A) {} async fn handler(_: A) {}
@ -13,13 +14,13 @@ struct A;
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
AppState: FromRef<S>, AppState: FromRef<S>,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug, !Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug, !Display))]
| ^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug))]
| ^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Display))]
| ^^^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Error, !Error))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: `Error` opt out specified more than once
--> tests/from_request/fail/derive_opt_out_duplicate.rs:4:42
|
4 | #[from_request(rejection_derive(!Error, !Error))]
| ^^^^^

View file

@ -1,7 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error))]
enum Extractor {}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use `rejection_derive` on enums
--> tests/from_request/fail/enum_rejection_derive.rs:4:16
|
4 | #[from_request(rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^

View file

@ -1,12 +0,0 @@
use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error))]
struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {}
fn main() {
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -1,21 +0,0 @@
error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)]
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:6:18
|
6 | struct Extractor<T>(T);
| ^
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46
|
11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
| top_level_handler_fn!(get, GET);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get`
= note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error), via(axum::Extension))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use both `rejection_derive` and `via`
--> tests/from_request/fail/rejection_derive_and_via.rs:4:42
|
4 | #[from_request(rejection_derive(!Error), via(axum::Extension))]
| ^^^

View file

@ -1,4 +1,4 @@
error: expected one of: `via`, `rejection_derive`, `rejection` error: expected `via` or `rejection`
--> tests/from_request/fail/unknown_attr_container.rs:4:16 --> tests/from_request/fail/unknown_attr_container.rs:4:16
| |
4 | #[from_request(foo)] 4 | #[from_request(foo)]

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(via(axum::Extension), rejection_derive(!Error))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use both `via` and `rejection_derive`
--> tests/from_request/fail/via_and_rejection_derive.rs:4:38
|
4 | #[from_request(via(axum::Extension), rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^

View file

@ -1,6 +1,7 @@
use axum::{ use axum::{
body::Body, body::Body,
extract::{rejection::JsonRejection, FromRequest, Json}, extract::{FromRequest, Json},
response::Response,
}; };
use axum_macros::FromRequest; use axum_macros::FromRequest;
use serde::Deserialize; use serde::Deserialize;
@ -15,7 +16,7 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<(), Body, Rejection = JsonRejection>, Extractor: FromRequest<(), Body, Rejection = Response>,
{ {
} }

View file

@ -1,38 +0,0 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
response::{IntoResponse, Response},
};
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display, !Error))]
struct Extractor {
other: OtherExtractor,
}
struct OtherExtractor;
#[async_trait]
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send,
S: Send + Sync,
{
type Rejection = OtherExtractorRejection;
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
#[derive(Debug)]
struct OtherExtractorRejection;
impl IntoResponse for OtherExtractorRejection {
fn into_response(self) -> Response {
unimplemented!()
}
}
fn main() {}

View file

@ -1,10 +1,10 @@
use axum::{ use axum::{
body::Body, body::Body,
extract::{FromRequest, TypedHeader, rejection::{TypedHeaderRejection, StringRejection}}, extract::{FromRequest, TypedHeader, rejection::TypedHeaderRejection},
response::Response,
headers::{self, UserAgent}, headers::{self, UserAgent},
}; };
use axum_macros::FromRequest; use axum_macros::FromRequest;
use std::convert::Infallible;
#[derive(FromRequest)] #[derive(FromRequest)]
struct Extractor { struct Extractor {
@ -18,34 +18,8 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, Extractor: FromRequest<(), Body, Rejection = Response>,
{ {
} }
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::Uri(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Body(inner) => {
let _: StringRejection = inner;
}
ExtractorRejection::UserAgent(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::ContentType(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::Etag(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Host(inner) => {
let _: Infallible = inner;
}
}
}
fn main() {} fn main() {}

View file

@ -1,13 +1,13 @@
use axum::{ use axum::{
body::Body, body::Body,
response::Response,
extract::{ extract::{
rejection::{ExtensionRejection, TypedHeaderRejection}, rejection::TypedHeaderRejection,
Extension, FromRequest, TypedHeader, Extension, FromRequest, TypedHeader,
}, },
headers::{self, UserAgent}, headers::{self, UserAgent},
}; };
use axum_macros::FromRequest; use axum_macros::FromRequest;
use std::convert::Infallible;
#[derive(FromRequest)] #[derive(FromRequest)]
struct Extractor { struct Extractor {
@ -25,33 +25,10 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, Extractor: FromRequest<(), Body, Rejection = Response>,
{ {
} }
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::State(inner) => {
let _: ExtensionRejection = inner;
}
ExtractorRejection::UserAgent(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::ContentType(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::Etag(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Host(inner) => {
let _: Infallible = inner;
}
}
}
#[derive(Clone)] #[derive(Clone)]
struct State; struct State;

View file

@ -1,7 +1,7 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{rejection::ExtensionRejection, FromRequest, RequestParts}, extract::{rejection::ExtensionRejection, FromRequest},
http::StatusCode, http::{StatusCode, Request},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::get, routing::get,
Extension, Router, Extension, Router,
@ -36,7 +36,7 @@ where
// this rejection doesn't implement `Display` and `Error` // this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String); type Rejection = (StatusCode, String);
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
todo!() todo!()
} }
} }

View file

@ -1,4 +1,4 @@
use axum::extract::{Query, rejection::*}; use axum::extract::Query;
use axum_macros::FromRequest; use axum_macros::FromRequest;
use serde::Deserialize; use serde::Deserialize;
@ -17,18 +17,4 @@ where
{ {
} }
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::QueryPayload(inner) => {
let _: QueryRejection = inner;
}
ExtractorRejection::JsonPayload(inner) => {
let _: JsonRejection = inner;
}
}
}
fn main() {} fn main() {}

View file

@ -1,4 +1,5 @@
use axum::extract::{Query, rejection::*}; use axum::extract::Query;
use axum::response::Response;
use axum_macros::FromRequest; use axum_macros::FromRequest;
use serde::Deserialize; use serde::Deserialize;
@ -8,26 +9,12 @@ struct Extractor(
#[from_request(via(axum::extract::Json))] Payload, #[from_request(via(axum::extract::Json))] Payload,
); );
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::QueryPayload(inner) => {
let _: QueryRejection = inner;
}
ExtractorRejection::JsonPayload(inner) => {
let _: JsonRejection = inner;
}
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Payload {} struct Payload {}
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<(), axum::body::Body>, Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = Response>,
{ {
} }

View file

@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is
(T0, T1, T2, T3) (T0, T1, T2, T3)
and 138 others and 138 others
= note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath`
= note: required because of the requirements on the impl of `FromRequest<S, B>` for `axum::extract::Path<MyPath>` = note: required because of the requirements on the impl of `FromRequestParts<S>` for `axum::extract::Path<MyPath>`
= note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -16,7 +16,6 @@ async fn result_handler(_: Result<UsersShow, PathRejection>) {}
#[typed_path("/users")] #[typed_path("/users")]
struct UsersIndex; struct UsersIndex;
#[axum_macros::debug_handler]
async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {} async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
fn main() { fn main() {

View file

@ -237,18 +237,48 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
} }
} }
``` ```
- **breaking:** The following types or traits have a new `S` type param - **breaking:** It is now only possible for one extractor per handler to consume
(`()` by default) which represents the state ([#1155]): the request body. In 0.5 doing so would result in runtime errors but in 0.6 it
- `FromRequest` is a compile error ([#1272])
- `RequestParts`
- `Router` axum enforces this by only allowing the _last_ extractor to consume the
- `MethodRouter` request.
- `Handler`
For example:
```rust
use axum::{Json, http::HeaderMap};
// This wont compile on 0.6 because both `Json` and `String` need to consume
// the request body. You can use either `Json` or `String`, but not both.
async fn handler_1(
json: Json<serde_json::Value>,
string: String,
) {}
// This won't work either since `Json` is not the last extractor.
async fn handler_2(
json: Json<serde_json::Value>,
headers: HeaderMap,
) {}
// This works!
async fn handler_3(
headers: HeaderMap,
json: Json<serde_json::Value>,
) {}
```
This is done by reworking the `FromRequest` trait and introducing a new
`FromRequestParts` trait.
If your extractor needs to consume the request body then you should implement
`FromRequest`, otherwise implement `FromRequestParts`.
This extractor in 0.5: This extractor in 0.5:
```rust ```rust
struct MyExtractor; struct MyExtractor { /* ... */ }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for MyExtractor impl<B> FromRequest<B> for MyExtractor
@ -266,22 +296,53 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Becomes this in 0.6: Becomes this in 0.6:
```rust ```rust
struct MyExtractor; use axum::{
extract::{FromRequest, FromRequestParts},
http::{StatusCode, Request, request::Parts},
async_trait,
};
struct MyExtractor { /* ... */ }
// implement `FromRequestParts` if you don't need to consume the request body
#[async_trait]
impl<S> FromRequestParts<S> for MyExtractor
where
S: Send + Sync,
B: Send + 'static,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// ...
}
}
// implement `FromRequest` if you do need to consume the request body
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for MyExtractor impl<S, B> FromRequest<S, B> for MyExtractor
where where
S: Send + Sync, S: Send + Sync,
B: Send, B: Send + 'static,
{ {
type Rejection = StatusCode; type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
// ... // ...
} }
} }
``` ```
- **breaking:** `RequestParts` has been removed as part of the `FromRequest`
rework ([#1272])
- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272])
- **breaking:** The following types or traits have a new `S` type param
which represents the state ([#1155]):
- `Router`, defaults to `()`
- `MethodRouter`, defaults to `()`
- `FromRequest`, no default
- `Handler`, no default
## Middleware ## Middleware
- **breaking:** Remove `extractor_middleware` which was previously deprecated. - **breaking:** Remove `extractor_middleware` which was previously deprecated.
@ -310,6 +371,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1248]: https://github.com/tokio-rs/axum/pull/1248 [#1248]: https://github.com/tokio-rs/axum/pull/1248
[#1272]: https://github.com/tokio-rs/axum/pull/1272
[#924]: https://github.com/tokio-rs/axum/pull/924 [#924]: https://github.com/tokio-rs/axum/pull/924
# 0.5.15 (9. August, 2022) # 0.5.15 (9. August, 2022)

View file

@ -5,14 +5,15 @@ Types and traits for extracting data from requests.
- [Intro](#intro) - [Intro](#intro)
- [Common extractors](#common-extractors) - [Common extractors](#common-extractors)
- [Applying multiple extractors](#applying-multiple-extractors) - [Applying multiple extractors](#applying-multiple-extractors)
- [Be careful when extracting `Request`](#be-careful-when-extracting-request) - [The order of extractors](#the-order-of-extractors)
- [Optional extractors](#optional-extractors) - [Optional extractors](#optional-extractors)
- [Customizing extractor responses](#customizing-extractor-responses) - [Customizing extractor responses](#customizing-extractor-responses)
- [Accessing inner errors](#accessing-inner-errors) - [Accessing inner errors](#accessing-inner-errors)
- [Defining custom extractors](#defining-custom-extractors) - [Defining custom extractors](#defining-custom-extractors)
- [Accessing other extractors in `FromRequest` implementations](#accessing-other-extractors-in-fromrequest-implementations) - [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations)
- [Request body extractors](#request-body-extractors) - [Request body extractors](#request-body-extractors)
- [Running extractors from middleware](#running-extractors-from-middleware) - [Running extractors from middleware](#running-extractors-from-middleware)
- [Wrapping extractors](#wrapping-extractors)
# Intro # Intro
@ -152,83 +153,74 @@ async fn get_user_things(
# }; # };
``` ```
# The order of extractors
Extractors always run in the order of the function parameters that is from Extractors always run in the order of the function parameters that is from
left to right. left to right.
# Be careful when extracting `Request` The request body is an asynchronous stream that can only be consumed once.
Therefore you can only have one extractor that consumes the request body. axum
enforces by that requiring such extractors to be the _last_ argument your
handler takes.
[`Request`] is itself an extractor: For example
```rust,no_run ```rust
use axum::{http::Request, body::Body}; use axum::http::{Method, HeaderMap};
async fn handler(request: Request<Body>) { async fn handler(
// `Method` and `HeaderMap` don't consume the request body so they can
// put anywhere in the argument list
method: Method,
headers: HeaderMap,
// `String` consumes the request body and thus must be the last extractor
body: String,
) {
// ... // ...
} }
#
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
``` ```
However be careful when combining it with other extractors since it will consume We get a compile error if `String` isn't the last extractor:
all extensions and the request body. Therefore it is recommended to always apply
the request extractor last:
```rust,no_run ```rust,compile_fail
use axum::{http::Request, Extension, body::Body}; use axum::http::Method;
// this will fail at runtime since `Request<Body>` will have consumed all the async fn handler(
// extensions so `Extension<State>` will be missing // this doesn't work since `String` must be the last argument
async fn broken( body: String,
request: Request<Body>, method: Method,
Extension(state): Extension<State>,
) { ) {
// ... // ...
} }
#
// this will work since we extract `Extension<State>` before `Request<Body>` # let _: axum::routing::MethodRouter = axum::routing::get(handler);
async fn works(
Extension(state): Extension<State>,
request: Request<Body>,
) {
// ...
}
#[derive(Clone)]
struct State {};
``` ```
# Extracting request bodies This also means you cannot consume the request body twice:
Since request bodies are asynchronous streams they can only be extracted once: ```rust,compile_fail
use axum::Json;
use serde::Deserialize;
```rust,no_run #[derive(Deserialize)]
use axum::{Json, http::Request, body::{Bytes, Body}}; struct Payload {}
use serde_json::Value;
// this will fail at runtime since `Json<Value>` and `Bytes` both attempt to extract async fn handler(
// the body // `String` and `Json` both consume the request body
// // so they cannot both be used
// the solution is to only extract the body once so remove either string_body: String,
// `body_json: Json<Value>` or `body_bytes: Bytes` json_body: Json<Payload>,
async fn broken(
body_json: Json<Value>,
body_bytes: Bytes,
) {
// ...
}
// this doesn't work either for the same reason: `Bytes` and `Request<Body>`
// both extract the body
async fn also_broken(
body_json: Json<Value>,
request: Request<Body>,
) { ) {
// ... // ...
} }
#
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
``` ```
Also keep this in mind if you extract or otherwise consume the body in axum enforces this by requiring the last extractor implements [`FromRequest`]
middleware. You either need to not extract the body in handlers or make sure and all others implement [`FromRequestParts`].
your middleware reinserts the body using [`RequestParts::body_mut`] so it's
available to handlers.
# Optional extractors # Optional extractors
@ -407,29 +399,38 @@ happen without major breaking versions.
# Defining custom extractors # Defining custom extractors
You can also define your own extractors by implementing [`FromRequest`]: You can also define your own extractors by implementing either
[`FromRequestParts`] or [`FromRequest`].
## Implementing `FromRequestParts`
Implement `FromRequestParts` if your extractor doesn't need access to the
request body:
```rust,no_run ```rust,no_run
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
routing::get, routing::get,
Router, Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
request::Parts,
},
}; };
use http::{StatusCode, header::{HeaderValue, USER_AGENT}};
struct ExtractUserAgent(HeaderValue); struct ExtractUserAgent(HeaderValue);
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for ExtractUserAgent impl<S> FromRequestParts<S> for ExtractUserAgent
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (StatusCode, &'static str); type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = req.headers().get(USER_AGENT) { if let Some(user_agent) = parts.headers.get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone())) Ok(ExtractUserAgent(user_agent.clone()))
} else { } else {
Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
@ -447,7 +448,58 @@ let app = Router::new().route("/foo", get(handler));
# }; # };
``` ```
# Accessing other extractors in [`FromRequest`] implementations ## Implementing `FromRequest`
If your extractor needs to consume the request body you must implement [`FromRequest`]
```rust,no_run
use axum::{
async_trait,
extract::FromRequest,
response::{Response, IntoResponse},
body::Bytes,
routing::get,
Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
Request,
},
};
struct ValidatedBody(Bytes);
#[async_trait]
impl<S, B> FromRequest<S, B> for ValidatedBody
where
Bytes: FromRequest<S, B>,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let body = Bytes::from_request(req, state)
.await
.map_err(IntoResponse::into_response)?;
// do validation...
Ok(Self(body))
}
}
async fn handler(ValidatedBody(body): ValidatedBody) {
// ...
}
let app = Router::new().route("/foo", get(handler));
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
# Accessing other extractors in `FromRequest` or `FromRequestParts` implementations
When defining custom extractors you often need to access another extractors When defining custom extractors you often need to access another extractors
in your implementation. in your implementation.
@ -455,9 +507,9 @@ in your implementation.
```rust ```rust
use axum::{ use axum::{
async_trait, async_trait,
extract::{Extension, FromRequest, RequestParts, TypedHeader}, extract::{Extension, FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization}, headers::{authorization::Bearer, Authorization},
http::StatusCode, http::{StatusCode, request::Parts},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::get, routing::get,
Router, Router,
@ -473,20 +525,19 @@ struct AuthenticatedUser {
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for AuthenticatedUser impl<S> FromRequestParts<S> for AuthenticatedUser
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(token)) = let TypedHeader(Authorization(token)) =
TypedHeader::<Authorization<Bearer>>::from_request(req) TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
let Extension(state): Extension<State> = Extension::from_request(req) let Extension(state): Extension<State> = Extension::from_request_parts(parts, state)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
@ -584,14 +635,13 @@ let app = Router::new()
# Running extractors from middleware # Running extractors from middleware
Extractors can also be run from middleware by making a [`RequestParts`] and Extractors can also be run from middleware:
running your extractor:
```rust ```rust
use axum::{ use axum::{
Router, Router,
middleware::{self, Next}, middleware::{self, Next},
extract::{RequestParts, TypedHeader}, extract::{TypedHeader, FromRequestParts},
http::{Request, StatusCode}, http::{Request, StatusCode},
response::Response, response::Response,
headers::authorization::{Authorization, Bearer}, headers::authorization::{Authorization, Bearer},
@ -604,12 +654,11 @@ async fn auth_middleware<B>(
where where
B: Send, B: Send,
{ {
// running extractors requires a `RequestParts` // running extractors requires a `axum::http::request::Parts`
let mut request_parts = RequestParts::new(request); let (mut parts, body) = request.into_parts();
// `TypedHeader<Authorization<Bearer>>` extracts the auth token but // `TypedHeader<Authorization<Bearer>>` extracts the auth token
// `RequestParts::extract` works with anything that implements `FromRequest` let auth = TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, &())
let auth = request_parts.extract::<TypedHeader<Authorization<Bearer>>>()
.await .await
.map_err(|_| StatusCode::UNAUTHORIZED)?; .map_err(|_| StatusCode::UNAUTHORIZED)?;
@ -617,14 +666,8 @@ where
return Err(StatusCode::UNAUTHORIZED); return Err(StatusCode::UNAUTHORIZED);
} }
// get the request back so we can run `next` // reconstruct the request
// let request = Request::from_parts(parts, body);
// `try_into_request` will fail if you have extracted the request body. We
// know that `TypedHeader` never does that.
//
// see the `consume-body-in-extractor-or-middleware` example if you need to
// extract the body
let request = request_parts.try_into_request().expect("body extracted");
Ok(next.run(request).await) Ok(next.run(request).await)
} }
@ -638,8 +681,81 @@ let app = Router::new().layer(middleware::from_fn(auth_middleware));
# let _: Router<()> = app; # let _: Router<()> = app;
``` ```
# Wrapping extractors
If you want write an extractor that generically wraps another extractor (that
may or may not consume the request body) you should implement both
[`FromRequest`] and [`FromRequestParts`]:
```rust
use axum::{
Router,
routing::get,
extract::{FromRequest, FromRequestParts},
http::{Request, HeaderMap, request::Parts},
async_trait,
};
use std::time::{Instant, Duration};
// an extractor that wraps another and measures how long time it takes to run
struct Timing<E> {
extractor: E,
duration: Duration,
}
// we must implement both `FromRequestParts`
#[async_trait]
impl<S, T> FromRequestParts<S> for Timing<T>
where
S: Send + Sync,
T: FromRequestParts<S>,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request_parts(parts, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
// and `FromRequest`
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Timing<T>
where
B: Send + 'static,
S: Send + Sync,
T: FromRequest<S, B>,
{
type Rejection = T::Rejection;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request(req, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
async fn handler(
// this uses the `FromRequestParts` impl
_: Timing<HeaderMap>,
// this uses the `FromRequest` impl
_: Timing<String>,
) {}
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
```
[`body::Body`]: crate::body::Body [`body::Body`]: crate::body::Body
[customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs
[`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html [`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html
[`Request`]: https://docs.rs/http/latest/http/struct.Request.html [`Request`]: https://docs.rs/http/latest/http/struct.Request.html
[`RequestParts::body_mut`]: crate::extract::RequestParts::body_mut [`RequestParts::body_mut`]: crate::extract::RequestParts::body_mut
[`JsonRejection::JsonDataError`]: rejection::JsonRejection::JsonDataError

View file

@ -1,8 +1,8 @@
#![doc = include_str!("../docs/error_handling.md")] #![doc = include_str!("../docs/error_handling.md")]
use crate::{ use crate::{
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
http::{Request, StatusCode}, http::Request,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::{ use std::{
@ -161,7 +161,7 @@ macro_rules! impl_service {
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
Res: IntoResponse, Res: IntoResponse,
$( $ty: FromRequest<(), B> + Send,)* $( $ty: FromRequestParts<()> + Send,)*
B: Send + 'static, B: Send + 'static,
{ {
type Response = Response; type Response = Response;
@ -181,21 +181,16 @@ macro_rules! impl_service {
let inner = std::mem::replace(&mut self.inner, clone); let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move { let future = Box::pin(async move {
let mut req = RequestParts::new(req); let (mut parts, body) = req.into_parts();
$( $(
let $ty = match $ty::from_request(&mut req).await { let $ty = match $ty::from_request_parts(&mut parts, &()).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response()), Err(rejection) => return Ok(rejection.into_response()),
}; };
)* )*
let req = match req.try_into_request() { let req = Request::from_parts(parts, body);
Ok(req) => req,
Err(err) => {
return Ok((StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response());
}
};
match inner.oneshot(req).await { match inner.oneshot(req).await {
Ok(res) => Ok(res.into_response()), Ok(res) => Ok(res.into_response()),

View file

@ -1,10 +1,10 @@
use crate::{ use crate::{extract::rejection::*, response::IntoResponseParts};
extract::{rejection::*, FromRequest, RequestParts},
response::IntoResponseParts,
};
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response, ResponseParts}; use axum_core::{
use http::Request; extract::FromRequestParts,
response::{IntoResponse, Response, ResponseParts},
};
use http::{request::Parts, Request};
use std::{ use std::{
convert::Infallible, convert::Infallible,
ops::Deref, ops::Deref,
@ -73,17 +73,16 @@ use tower_service::Service;
pub struct Extension<T>(pub T); pub struct Extension<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Extension<T> impl<T, S> FromRequestParts<S> for Extension<T>
where where
T: Clone + Send + Sync + 'static, T: Clone + Send + Sync + 'static,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ExtensionRejection; type Rejection = ExtensionRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let value = req let value = req
.extensions() .extensions
.get::<T>() .get::<T>()
.ok_or_else(|| { .ok_or_else(|| {
MissingExtension::from_err(format!( MissingExtension::from_err(format!(

View file

@ -4,9 +4,10 @@
//! //!
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
use super::{Extension, FromRequest, RequestParts}; use super::{Extension, FromRequestParts};
use crate::middleware::AddExtension; use crate::middleware::AddExtension;
use async_trait::async_trait; use async_trait::async_trait;
use http::request::Parts;
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
use std::{ use std::{
convert::Infallible, convert::Infallible,
@ -128,16 +129,15 @@ opaque_future! {
pub struct ConnectInfo<T>(pub T); pub struct ConnectInfo<T>(pub T);
#[async_trait] #[async_trait]
impl<S, B, T> FromRequest<S, B> for ConnectInfo<T> impl<S, T> FromRequestParts<S> for ConnectInfo<T>
where where
B: Send,
S: Send + Sync, S: Send + Sync,
T: Clone + Send + Sync + 'static, T: Clone + Send + Sync + 'static,
{ {
type Rejection = <Extension<Self> as FromRequest<S, B>>::Rejection; type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request(req).await?; let Extension(connect_info) = Extension::<Self>::from_request_parts(parts, state).await?;
Ok(connect_info) Ok(connect_info)
} }
} }

View file

@ -1,7 +1,7 @@
use super::{rejection::*, FromRequest, RequestParts}; use super::{rejection::*, FromRequest};
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::IntoResponse; use axum_core::{extract::FromRequestParts, response::IntoResponse};
use http::Method; use http::{request::Parts, Method, Request};
use std::ops::Deref; use std::ops::Deref;
/// Extractor that will reject requests with a body larger than some size. /// Extractor that will reject requests with a body larger than some size.
@ -40,25 +40,58 @@ impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where where
T: FromRequest<S, B>, T: FromRequest<S, B>,
T::Rejection: IntoResponse, T::Rejection: IntoResponse,
B: Send, B: Send + 'static,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = ContentLengthLimitRejection<T::Rejection>; type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let content_length = req let (parts, body) = req.into_parts();
.headers() validate::<_, N>(&parts)?;
let req = Request::from_parts(parts, body);
let value = T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
}
#[async_trait]
impl<T, S, const N: u64> FromRequestParts<S> for ContentLengthLimit<T, N>
where
T: FromRequestParts<S>,
T::Rejection: IntoResponse,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
validate::<_, N>(parts)?;
let value = T::from_request_parts(parts, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
}
fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<E>> {
let content_length = parts
.headers
.get(http::header::CONTENT_LENGTH) .get(http::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); .and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
match (content_length, req.method()) { match (content_length, &parts.method) {
(content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => { (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
if content_length.is_some() { if content_length.is_some() {
return Err(ContentLengthLimitRejection::ContentLengthNotAllowed( return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
ContentLengthNotAllowed, ContentLengthNotAllowed,
)); ));
} else if req } else if parts
.headers() .headers
.get(http::header::TRANSFER_ENCODING) .get(http::header::TRANSFER_ENCODING)
.map_or(false, |value| value.as_bytes() == b"chunked") .map_or(false, |value| value.as_bytes() == b"chunked")
{ {
@ -76,12 +109,7 @@ where
_ => {} _ => {}
} }
let value = T::from_request(req) Ok(())
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
} }
impl<T, const N: u64> Deref for ContentLengthLimit<T, N> { impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {

View file

@ -1,9 +1,12 @@
use super::{ use super::{
rejection::{FailedToResolveHost, HostRejection}, rejection::{FailedToResolveHost, HostRejection},
FromRequest, RequestParts, FromRequestParts,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use http::header::{HeaderMap, FORWARDED}; use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
@ -21,35 +24,34 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
pub struct Host(pub String); pub struct Host(pub String);
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Host impl<S> FromRequestParts<S> for Host
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = HostRejection; type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(req.headers()) { if let Some(host) = parse_forwarded(&parts.headers) {
return Ok(Host(host.to_owned())); return Ok(Host(host.to_owned()));
} }
if let Some(host) = req if let Some(host) = parts
.headers() .headers
.get(X_FORWARDED_HOST_HEADER_KEY) .get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok()) .and_then(|host| host.to_str().ok())
{ {
return Ok(Host(host.to_owned())); return Ok(Host(host.to_owned()));
} }
if let Some(host) = req if let Some(host) = parts
.headers() .headers
.get(http::header::HOST) .get(http::header::HOST)
.and_then(|host| host.to_str().ok()) .and_then(|host| host.to_str().ok())
{ {
return Ok(Host(host.to_owned())); return Ok(Host(host.to_owned()));
} }
if let Some(host) = req.uri().host() { if let Some(host) = parts.uri.host() {
return Ok(Host(host.to_owned())); return Ok(Host(host.to_owned()));
} }

View file

@ -1,5 +1,6 @@
use super::{rejection::*, FromRequest, RequestParts}; use super::{rejection::*, FromRequestParts};
use async_trait::async_trait; use async_trait::async_trait;
use http::request::Parts;
use std::sync::Arc; use std::sync::Arc;
/// Access the path in the router that matches the request. /// Access the path in the router that matches the request.
@ -64,16 +65,15 @@ impl MatchedPath {
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for MatchedPath impl<S> FromRequestParts<S> for MatchedPath
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = MatchedPathRejection; type Rejection = MatchedPathRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let matched_path = req let matched_path = parts
.extensions() .extensions
.get::<Self>() .get::<Self>()
.ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))? .ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))?
.clone(); .clone();

View file

@ -1,7 +1,6 @@
#![doc = include_str!("../docs/extract.md")] #![doc = include_str!("../docs/extract.md")]
use http::header; use http::header::{self, HeaderMap};
use rejection::*;
pub mod connect_info; pub mod connect_info;
pub mod path; pub mod path;
@ -17,7 +16,7 @@ mod request_parts;
mod state; mod state;
#[doc(inline)] #[doc(inline)]
pub use axum_core::extract::{FromRef, FromRequest, RequestParts}; pub use axum_core::extract::{FromRef, FromRequest, FromRequestParts};
#[doc(inline)] #[doc(inline)]
#[allow(deprecated)] #[allow(deprecated)]
@ -75,16 +74,9 @@ pub use self::ws::WebSocketUpgrade;
#[doc(no_inline)] #[doc(no_inline)]
pub use crate::TypedHeader; pub use crate::TypedHeader;
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or_else(BodyAlreadyExtracted::default)
}
// this is duplicated in `axum-extra/src/extract/form.rs` // this is duplicated in `axum-extra/src/extract/form.rs`
pub(super) fn has_content_type<S, B>( pub(super) fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
req: &RequestParts<S, B>, let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
expected_content_type: &mime::Mime,
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type content_type
} else { } else {
return false; return false;

View file

@ -2,12 +2,13 @@
//! //!
//! See [`Multipart`] for more details. //! See [`Multipart`] for more details.
use super::{rejection::*, BodyStream, FromRequest, RequestParts}; use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody}; use crate::body::{Bytes, HttpBody};
use crate::BoxError; use crate::BoxError;
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::stream::Stream; use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE}; use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
use std::{ use std::{
fmt, fmt,
pin::Pin, pin::Pin,
@ -58,10 +59,12 @@ where
{ {
type Rejection = MultipartRejection; type Rejection = MultipartRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?; let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let headers = req.headers(); let stream = match BodyStream::from_request(req, state).await {
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; Ok(stream) => stream,
Err(err) => match err {},
};
let multipart = multer::Multipart::new(stream, boundary); let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart }) Ok(Self { inner: multipart })
} }
@ -224,7 +227,6 @@ composite_rejection! {
/// ///
/// Contains one variant for each way the [`Multipart`] extractor can fail. /// Contains one variant for each way the [`Multipart`] extractor can fail.
pub enum MultipartRejection { pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary, InvalidBoundary,
} }
} }

View file

@ -4,12 +4,12 @@
mod de; mod de;
use crate::{ use crate::{
extract::{rejection::*, FromRequest, RequestParts}, extract::{rejection::*, FromRequestParts},
routing::url_params::UrlParams, routing::url_params::UrlParams,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response}; use axum_core::response::{IntoResponse, Response};
use http::StatusCode; use http::{request::Parts, StatusCode};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{ use std::{
fmt, fmt,
@ -163,16 +163,15 @@ impl<T> DerefMut for Path<T> {
} }
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Path<T> impl<T, S> FromRequestParts<S> for Path<T>
where where
T: DeserializeOwned + Send, T: DeserializeOwned + Send,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = PathRejection; type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<UrlParams>() { let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params, Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => { Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError { let err = PathDeserializationError {
@ -413,8 +412,7 @@ impl std::error::Error for FailedToDeserializePathParams {}
mod tests { mod tests {
use super::*; use super::*;
use crate::{routing::get, test_helpers::*, Router}; use crate::{routing::get, test_helpers::*, Router};
use http::{Request, StatusCode}; use http::StatusCode;
use hyper::Body;
use std::collections::HashMap; use std::collections::HashMap;
#[tokio::test] #[tokio::test]
@ -519,20 +517,6 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[tokio::test]
async fn when_extensions_are_missing() {
let app = Router::new().route("/:key", get(|_: Request<Body>, _: Path<String>| async {}));
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await,
"No paths parameters found for matched route. Are you also extracting `Request<_>`?"
);
}
#[tokio::test] #[tokio::test]
async fn str_reference_deserialize() { async fn str_reference_deserialize() {
struct Param(String); struct Param(String);

View file

@ -1,5 +1,6 @@
use super::{rejection::*, FromRequest, RequestParts}; use super::{rejection::*, FromRequestParts};
use async_trait::async_trait; use async_trait::async_trait;
use http::request::Parts;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::ops::Deref; use std::ops::Deref;
@ -49,16 +50,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T); pub struct Query<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for Query<T> impl<T, S> FromRequestParts<S> for Query<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = QueryRejection; type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default(); let query = parts.uri.query().unwrap_or_default();
let value = serde_urlencoded::from_str(query) let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Query(value)) Ok(Query(value))
@ -76,15 +76,17 @@ impl<T> Deref for Query<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::extract::RequestParts; use axum_core::extract::FromRequest;
use http::Request; use http::Request;
use serde::Deserialize; use serde::Deserialize;
use std::fmt::Debug; use std::fmt::Debug;
async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) { async fn check<T>(uri: impl AsRef<str>, value: T)
where
T: DeserializeOwned + PartialEq + Debug,
{
let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
let mut req = RequestParts::new(req); assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value);
} }
#[tokio::test] #[tokio::test]

View file

@ -1,5 +1,6 @@
use super::{FromRequest, RequestParts}; use super::FromRequestParts;
use async_trait::async_trait; use async_trait::async_trait;
use http::request::Parts;
use std::convert::Infallible; use std::convert::Infallible;
/// Extractor that extracts the raw query string, without parsing it. /// Extractor that extracts the raw query string, without parsing it.
@ -27,15 +28,14 @@ use std::convert::Infallible;
pub struct RawQuery(pub Option<String>); pub struct RawQuery(pub Option<String>);
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for RawQuery impl<S> FromRequestParts<S> for RawQuery
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = req.uri().query().map(|query| query.to_owned()); let query = parts.uri.query().map(|query| query.to_owned());
Ok(Self(query)) Ok(Self(query))
} }
} }

View file

@ -73,7 +73,7 @@ define_rejection! {
define_rejection! { define_rejection! {
#[status = INTERNAL_SERVER_ERROR] #[status = INTERNAL_SERVER_ERROR]
#[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"] #[body = "No paths parameters found for matched route"]
/// Rejection type used if axum's internal representation of path parameters /// Rejection type used if axum's internal representation of path parameters
/// is missing. This is commonly caused by extracting `Request<_>`. `Path` /// is missing. This is commonly caused by extracting `Request<_>`. `Path`
/// must be extracted first. /// must be extracted first.

View file

@ -1,11 +1,11 @@
use super::{rejection::*, take_body, Extension, FromRequest, RequestParts}; use super::{Extension, FromRequest, FromRequestParts};
use crate::{ use crate::{
body::{Body, Bytes, HttpBody}, body::{Body, Bytes, HttpBody},
BoxError, Error, BoxError, Error,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::stream::Stream; use futures_util::stream::Stream;
use http::Uri; use http::{request::Parts, Request, Uri};
use std::{ use std::{
convert::Infallible, convert::Infallible,
fmt, fmt,
@ -86,17 +86,16 @@ pub struct OriginalUri(pub Uri);
#[cfg(feature = "original-uri")] #[cfg(feature = "original-uri")]
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for OriginalUri impl<S> FromRequestParts<S> for OriginalUri
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request(req) let uri = Extension::<Self>::from_request_parts(parts, state)
.await .await
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) .unwrap_or_else(|_| Extension(OriginalUri(parts.uri.clone())))
.0; .0;
Ok(uri) Ok(uri)
} }
@ -148,10 +147,11 @@ where
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let body = take_body(req)? let body = req
.into_body()
.map_data(Into::into) .map_data(Into::into)
.map_err(|err| Error::new(err.into())); .map_err(|err| Error::new(err.into()));
let stream = BodyStream(SyncWrapper::new(Box::pin(body))); let stream = BodyStream(SyncWrapper::new(Box::pin(body)));
@ -203,40 +203,17 @@ where
B: Send, B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; Ok(Self(req.into_body()))
Ok(Self(body))
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{ use crate::{extract::Extension, routing::get, test_helpers::*, Router};
body::Body, use http::{Method, StatusCode};
extract::Extension,
routing::{get, post},
test_helpers::*,
Router,
};
use http::{Method, Request, StatusCode};
#[tokio::test]
async fn multiple_request_extractors() {
async fn handler(_: Request<Body>, _: Request<Body>) {}
let app = Router::new().route("/", post(handler));
let client = TestClient::new(app);
let res = client.post("/").body("hi there").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await,
"Cannot have two request body extractors for a single handler"
);
}
#[tokio::test] #[tokio::test]
async fn extract_request_parts() { async fn extract_request_parts() {
@ -256,19 +233,4 @@ mod tests {
let res = client.get("/").header("x-foo", "123").send().await; let res = client.get("/").header("x-foo", "123").send().await;
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[tokio::test]
async fn extract_request_parts_doesnt_consume_the_body() {
#[derive(Clone)]
struct Ext;
async fn handler(_parts: http::request::Parts, body: String) {
assert_eq!(body, "foo");
}
let client = TestClient::new(Router::new().route("/", get(handler)));
let res = client.get("/").body("foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
} }

View file

@ -1,5 +1,6 @@
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequest, RequestParts}; use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
use std::{ use std::{
convert::Infallible, convert::Infallible,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
@ -139,7 +140,8 @@ use std::{
/// to do it: /// to do it:
/// ///
/// ```rust /// ```rust
/// use axum_core::extract::{FromRequest, RequestParts, FromRef}; /// use axum_core::extract::{FromRequestParts, FromRef};
/// use http::request::Parts;
/// use async_trait::async_trait; /// use async_trait::async_trait;
/// use std::convert::Infallible; /// use std::convert::Infallible;
/// ///
@ -147,9 +149,8 @@ use std::{
/// struct MyLibraryExtractor; /// struct MyLibraryExtractor;
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyLibraryExtractor /// impl<S> FromRequestParts<S> for MyLibraryExtractor
/// where /// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState` /// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState` /// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>, /// MyLibraryState: FromRef<S>,
@ -157,9 +158,9 @@ use std::{
/// { /// {
/// type Rejection = Infallible; /// type Rejection = Infallible;
/// ///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from a reference to the state /// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state()); /// let state = MyLibraryState::from_ref(state);
/// ///
/// // ... /// // ...
/// # todo!() /// # todo!()
@ -171,23 +172,22 @@ use std::{
/// // ... /// // ...
/// } /// }
/// ``` /// ```
///
/// Note that you don't need to use the `State` extractor since you can access the state directly
/// from [`RequestParts`].
#[derive(Debug, Default, Clone, Copy)] #[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S); pub struct State<S>(pub S);
#[async_trait] #[async_trait]
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState> impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where where
B: Send,
InnerState: FromRef<OuterState>, InnerState: FromRef<OuterState>,
OuterState: Send + Sync, OuterState: Send + Sync,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(
let inner_state = InnerState::from_ref(req.state()); _parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state)) Ok(Self(inner_state))
} }
} }

View file

@ -95,7 +95,7 @@
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
use self::rejection::*; use self::rejection::*;
use super::{FromRequest, RequestParts}; use super::FromRequestParts;
use crate::{ use crate::{
body::{self, Bytes}, body::{self, Bytes},
response::Response, response::Response,
@ -107,7 +107,8 @@ use futures_util::{
stream::{Stream, StreamExt}, stream::{Stream, StreamExt},
}; };
use http::{ use http::{
header::{self, HeaderName, HeaderValue}, header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode, Method, StatusCode,
}; };
use hyper::upgrade::{OnUpgrade, Upgraded}; use hyper::upgrade::{OnUpgrade, Upgraded};
@ -275,41 +276,40 @@ impl WebSocketUpgrade {
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for WebSocketUpgrade impl<S> FromRequestParts<S> for WebSocketUpgrade
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = WebSocketUpgradeRejection; type Rejection = WebSocketUpgradeRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if req.method() != Method::GET { if parts.method != Method::GET {
return Err(MethodNotGet.into()); return Err(MethodNotGet.into());
} }
if !header_contains(req, header::CONNECTION, "upgrade") { if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into()); return Err(InvalidConnectionHeader.into());
} }
if !header_eq(req, header::UPGRADE, "websocket") { if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into()); return Err(InvalidUpgradeHeader.into());
} }
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13") { if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into()); return Err(InvalidWebSocketVersionHeader.into());
} }
let sec_websocket_key = req let sec_websocket_key = parts
.headers_mut() .headers
.remove(header::SEC_WEBSOCKET_KEY) .remove(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?; .ok_or(WebSocketKeyHeaderMissing)?;
let on_upgrade = req let on_upgrade = parts
.extensions_mut() .extensions
.remove::<OnUpgrade>() .remove::<OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?; .ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = req.headers().get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self { Ok(Self {
config: Default::default(), config: Default::default(),
@ -321,16 +321,16 @@ where
} }
} }
fn header_eq<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool { fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) { if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else { } else {
false false
} }
} }
fn header_contains<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool { fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) { let header = if let Some(header) = headers.get(&key) {
header header
} else { } else {
return false; return false;

View file

@ -1,10 +1,10 @@
use crate::body::{Bytes, HttpBody}; use crate::body::{Bytes, HttpBody};
use crate::extract::{has_content_type, rejection::*, FromRequest, RequestParts}; use crate::extract::{has_content_type, rejection::*, FromRequest};
use crate::BoxError; use crate::BoxError;
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response}; use axum_core::response::{IntoResponse, Response};
use http::header::CONTENT_TYPE; use http::header::CONTENT_TYPE;
use http::{Method, StatusCode}; use http::{Method, Request, StatusCode};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
use std::ops::Deref; use std::ops::Deref;
@ -59,25 +59,25 @@ pub struct Form<T>(pub T);
impl<T, S, B> FromRequest<S, B> for Form<T> impl<T, S, B> FromRequest<S, B> for Form<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = FormRejection; type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET { if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query) let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Form(value)) Ok(Form(value))
} else { } else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) { if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType.into()); return Err(InvalidFormContentType.into());
} }
let bytes = Bytes::from_request(req).await?; let bytes = Bytes::from_request(req, state).await?;
let value = serde_urlencoded::from_bytes(&bytes) let value = serde_urlencoded::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
@ -114,7 +114,6 @@ impl<T> Deref for Form<T> {
mod tests { mod tests {
use super::*; use super::*;
use crate::body::{Empty, Full}; use crate::body::{Empty, Full};
use crate::extract::RequestParts;
use http::Request; use http::Request;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Debug; use std::fmt::Debug;
@ -130,8 +129,7 @@ mod tests {
.uri(uri.as_ref()) .uri(uri.as_ref())
.body(Empty::<Bytes>::new()) .body(Empty::<Bytes>::new())
.unwrap(); .unwrap();
let mut req = RequestParts::new(req); assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
} }
async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) { async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
@ -146,8 +144,7 @@ mod tests {
serde_urlencoded::to_string(&value).unwrap().into(), serde_urlencoded::to_string(&value).unwrap().into(),
)) ))
.unwrap(); .unwrap();
let mut req = RequestParts::new(req); assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
} }
#[tokio::test] #[tokio::test]
@ -216,9 +213,8 @@ mod tests {
.into(), .into(),
)) ))
.unwrap(); .unwrap();
let mut req = RequestParts::new(req);
assert!(matches!( assert!(matches!(
Form::<Pagination>::from_request(&mut req) Form::<Pagination>::from_request(req, &())
.await .await
.unwrap_err(), .unwrap_err(),
FormRejection::InvalidFormContentType(InvalidFormContentType) FormRejection::InvalidFormContentType(InvalidFormContentType)

View file

@ -88,7 +88,7 @@ where
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
let handler = self.handler.clone(); let handler = self.handler.clone();
let future = Handler::call(handler, Arc::clone(&self.state), req); let future = Handler::call(handler, req, Arc::clone(&self.state));
let future = future.map(Ok as _); let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future) super::future::IntoServiceFuture::new(future)

View file

@ -78,7 +78,7 @@ where
.expect("state extension missing. This is a bug in axum, please file an issue"); .expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone(); let handler = self.handler.clone();
let future = Handler::call(handler, state, req); let future = Handler::call(handler, req, state);
let future = future.map(Ok as _); let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future) super::future::IntoServiceFuture::new(future)

View file

@ -37,7 +37,7 @@
use crate::{ use crate::{
body::Body, body::Body,
extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, RequestParts}, extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, FromRequestParts},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::IntoMakeService, routing::IntoMakeService,
}; };
@ -95,12 +95,12 @@ pub use self::{into_service::IntoService, with_state::WithState};
/// {} /// {}
/// ``` /// ```
#[doc = include_str!("../docs/debugging_handler_type_errors.md")] #[doc = include_str!("../docs/debugging_handler_type_errors.md")]
pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static { pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
/// The type of future calling this handler returns. /// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static; type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request. /// Call the handler with the given request.
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future; fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler. /// Apply a [`tower::Layer`] to the handler.
/// ///
@ -162,7 +162,7 @@ pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
} }
} }
impl<F, Fut, Res, S, B> Handler<(), S, B> for F impl<F, Fut, Res, S, B> Handler<((),), S, B> for F
where where
F: FnOnce() -> Fut + Clone + Send + 'static, F: FnOnce() -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
@ -171,37 +171,48 @@ where
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _state: Arc<S>, _req: Request<B>) -> Self::Future { fn call(self, _req: Request<B>, _state: Arc<S>) -> Self::Future {
Box::pin(async move { self().await.into_response() }) Box::pin(async move { self().await.into_response() })
} }
} }
macro_rules! impl_handler { macro_rules! impl_handler {
( $($ty:ident),* $(,)? ) => { (
#[allow(non_snake_case)] [$($ty:ident),*], $last:ident
impl<F, Fut, S, B, Res, $($ty,)*> Handler<($($ty,)*), S, B> for F ) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, S, B, Res, M, $($ty,)* $last> Handler<(M, $($ty,)* $last,), S, B> for F
where where
F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
B: Send + 'static, B: Send + 'static,
S: Send + Sync + 'static, S: Send + Sync + 'static,
Res: IntoResponse, Res: IntoResponse,
$( $ty: FromRequest<S, B> + Send,)* $( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B, M> + Send,
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future { fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::with_state_arc(state, req); let (mut parts, body) = req.into_parts();
let state = &state;
$( $(
let $ty = match $ty::from_request(&mut req).await { let $ty = match $ty::from_request_parts(&mut parts, state).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return rejection.into_response(), Err(rejection) => return rejection.into_response(),
}; };
)* )*
let res = self($($ty,)*).await; let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
let res = self($($ty,)* $last,).await;
res.into_response() res.into_response()
}) })
@ -210,7 +221,31 @@ macro_rules! impl_handler {
}; };
} }
all_the_tuples!(impl_handler); impl_handler!([], T1);
impl_handler!([T1], T2);
impl_handler!([T1, T2], T3);
impl_handler!([T1, T2, T3], T4);
impl_handler!([T1, T2, T3, T4], T5);
impl_handler!([T1, T2, T3, T4, T5], T6);
impl_handler!([T1, T2, T3, T4, T5, T6], T7);
impl_handler!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
/// ///
@ -259,7 +294,7 @@ where
{ {
type Future = future::LayeredFuture<B, L::Service>; type Future = future::LayeredFuture<B, L::Service>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future { fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
use futures_util::future::{FutureExt, Map}; use futures_util::future::{FutureExt, Map};
let svc = self.handler.with_state_arc(state); let svc = self.handler.with_state_arc(state);

View file

@ -1,14 +1,14 @@
use crate::{ use crate::{
body::{Bytes, HttpBody}, body::{Bytes, HttpBody},
extract::{rejection::*, FromRequest, RequestParts}, extract::{rejection::*, FromRequest},
BoxError, BoxError,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response}; use axum_core::response::{IntoResponse, Response};
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use http::{ use http::{
header::{self, HeaderValue}, header::{self, HeaderMap, HeaderValue},
StatusCode, Request, StatusCode,
}; };
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
@ -97,16 +97,16 @@ pub struct Json<T>(pub T);
impl<T, S, B> FromRequest<S, B> for Json<T> impl<T, S, B> FromRequest<S, B> for Json<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = JsonRejection; type Rejection = JsonRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if json_content_type(req) { if json_content_type(req.headers()) {
let bytes = Bytes::from_request(req).await?; let bytes = Bytes::from_request(req, state).await?;
let value = match serde_json::from_slice(&bytes) { let value = match serde_json::from_slice(&bytes) {
Ok(value) => value, Ok(value) => value,
@ -137,8 +137,8 @@ where
} }
} }
fn json_content_type<S, B>(req: &RequestParts<S, B>) -> bool { fn json_content_type(headers: &HeaderMap) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type content_type
} else { } else {
return false; return false;

View file

@ -93,8 +93,8 @@
//! //!
//! # Extractors //! # Extractors
//! //!
//! An extractor is a type that implements [`FromRequest`]. Extractors is how //! An extractor is a type that implements [`FromRequest`] or [`FromRequestParts`]. Extractors is
//! you pick apart the incoming request to get the parts your handler needs. //! how you pick apart the incoming request to get the parts your handler needs.
//! //!
//! ```rust //! ```rust
//! use axum::extract::{Path, Query, Json}; //! use axum::extract::{Path, Query, Json};
@ -302,9 +302,10 @@
//! //!
//! # Building integrations for axum //! # Building integrations for axum
//! //!
//! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations //! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or
//! should depend on the [`axum-core`] crate, instead of `axum` if possible. [`axum-core`] contains //! [`IntoResponse`] implementations should depend on the [`axum-core`] crate, instead of `axum` if
//! core types and traits and is less likely to receive breaking changes. //! possible. [`axum-core`] contains core types and traits and is less likely to receive breaking
//! changes.
//! //!
//! # Required dependencies //! # Required dependencies
//! //!
@ -376,6 +377,7 @@
//! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides //! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides
//! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/
//! [`FromRequest`]: crate::extract::FromRequest //! [`FromRequest`]: crate::extract::FromRequest
//! [`FromRequestParts`]: crate::extract::FromRequestParts
//! [`HeaderMap`]: http::header::HeaderMap //! [`HeaderMap`]: http::header::HeaderMap
//! [`Request`]: http::Request //! [`Request`]: http::Request
//! [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs //! [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
extract::{FromRequest, RequestParts}, extract::FromRequestParts,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use futures_util::{future::BoxFuture, ready}; use futures_util::{future::BoxFuture, ready};
@ -33,28 +33,27 @@ use tower_service::Service;
/// ///
/// ```rust /// ```rust
/// use axum::{ /// use axum::{
/// extract::{FromRequest, RequestParts}, /// extract::FromRequestParts,
/// middleware::from_extractor, /// middleware::from_extractor,
/// routing::{get, post}, /// routing::{get, post},
/// Router, /// Router,
/// http::{header, StatusCode, request::Parts},
/// }; /// };
/// use http::{header, StatusCode};
/// use async_trait::async_trait; /// use async_trait::async_trait;
/// ///
/// // An extractor that performs authorization. /// // An extractor that performs authorization.
/// struct RequireAuth; /// struct RequireAuth;
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<S, B> FromRequest<S, B> for RequireAuth /// impl<S> FromRequestParts<S> for RequireAuth
/// where /// where
/// B: Send,
/// S: Send + Sync, /// S: Send + Sync,
/// { /// {
/// type Rejection = StatusCode; /// type Rejection = StatusCode;
/// ///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// let auth_header = req /// let auth_header = parts
/// .headers() /// .headers
/// .get(header::AUTHORIZATION) /// .get(header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok()); /// .and_then(|value| value.to_str().ok());
/// ///
@ -169,7 +168,7 @@ where
impl<S, E, B> Service<Request<B>> for FromExtractor<S, E> impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
where where
E: FromRequest<(), B> + 'static, E: FromRequestParts<()> + 'static,
B: Default + Send + 'static, B: Default + Send + 'static,
S: Service<Request<B>> + Clone, S: Service<Request<B>> + Clone,
S::Response: IntoResponse, S::Response: IntoResponse,
@ -185,8 +184,9 @@ where
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let extract_future = Box::pin(async move { let extract_future = Box::pin(async move {
let mut req = RequestParts::new(req); let (mut parts, body) = req.into_parts();
let extracted = E::from_request(&mut req).await; let extracted = E::from_request_parts(&mut parts, &()).await;
let req = Request::from_parts(parts, body);
(req, extracted) (req, extracted)
}); });
@ -204,7 +204,7 @@ pin_project! {
#[allow(missing_debug_implementations)] #[allow(missing_debug_implementations)]
pub struct ResponseFuture<B, S, E> pub struct ResponseFuture<B, S, E>
where where
E: FromRequest<(), B>, E: FromRequestParts<()>,
S: Service<Request<B>>, S: Service<Request<B>>,
{ {
#[pin] #[pin]
@ -217,11 +217,11 @@ pin_project! {
#[project = StateProj] #[project = StateProj]
enum State<B, S, E> enum State<B, S, E>
where where
E: FromRequest<(), B>, E: FromRequestParts<()>,
S: Service<Request<B>>, S: Service<Request<B>>,
{ {
Extracting { Extracting {
future: BoxFuture<'static, (RequestParts<(), B>, Result<E, E::Rejection>)>, future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
}, },
Call { #[pin] future: S::Future }, Call { #[pin] future: S::Future },
} }
@ -229,7 +229,7 @@ pin_project! {
impl<B, S, E> Future for ResponseFuture<B, S, E> impl<B, S, E> Future for ResponseFuture<B, S, E>
where where
E: FromRequest<(), B>, E: FromRequestParts<()>,
S: Service<Request<B>>, S: Service<Request<B>>,
S::Response: IntoResponse, S::Response: IntoResponse,
B: Default, B: Default,
@ -247,7 +247,6 @@ where
match extracted { match extracted {
Ok(_) => { Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion"); let mut svc = this.svc.take().expect("future polled after completion");
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req); let future = svc.call(req);
State::Call { future } State::Call { future }
} }
@ -273,23 +272,25 @@ where
mod tests { mod tests {
use super::*; use super::*;
use crate::{handler::Handler, routing::get, test_helpers::*, Router}; use crate::{handler::Handler, routing::get, test_helpers::*, Router};
use http::{header, StatusCode}; use http::{header, request::Parts, StatusCode};
#[tokio::test] #[tokio::test]
async fn test_from_extractor() { async fn test_from_extractor() {
struct RequireAuth; struct RequireAuth;
#[async_trait::async_trait] #[async_trait::async_trait]
impl<S, B> FromRequest<S, B> for RequireAuth impl<S> FromRequestParts<S> for RequireAuth
where where
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = StatusCode; type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(
if let Some(auth) = req parts: &mut Parts,
.headers() _state: &S,
) -> Result<Self, Self::Rejection> {
if let Some(auth) = parts
.headers
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
{ {

View file

@ -1,5 +1,5 @@
use crate::response::{IntoResponse, Response}; use crate::response::{IntoResponse, Response};
use axum_core::extract::{FromRequest, RequestParts}; use axum_core::extract::{FromRequest, FromRequestParts};
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use http::Request; use http::Request;
use std::{ use std::{
@ -249,12 +249,15 @@ where
} }
macro_rules! impl_service { macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => { (
#[allow(non_snake_case)] [$($ty:ident),*], $last:ident
impl<F, Fut, Out, S, B, $($ty,)*> Service<Request<B>> for FromFn<F, S, ($($ty,)*)> ) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
where where
F: FnMut($($ty),*, Next<B>) -> Fut + Clone + Send + 'static, F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequest<(), B> + Send, )* $( $ty: FromRequestParts<()> + Send, )*
$last: FromRequest<(), B> + Send,
Fut: Future<Output = Out> + Send + 'static, Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static, Out: IntoResponse + 'static,
S: Service<Request<B>, Error = Infallible> S: Service<Request<B>, Error = Infallible>
@ -280,21 +283,29 @@ macro_rules! impl_service {
let mut f = self.f.clone(); let mut f = self.f.clone();
let future = Box::pin(async move { let future = Box::pin(async move {
let mut parts = RequestParts::new(req); let (mut parts, body) = req.into_parts();
$( $(
let $ty = match $ty::from_request(&mut parts).await { let $ty = match $ty::from_request_parts(&mut parts, &()).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return rejection.into_response(), Err(rejection) => return rejection.into_response(),
}; };
)* )*
let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, &()).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
let inner = ServiceBuilder::new() let inner = ServiceBuilder::new()
.boxed_clone() .boxed_clone()
.map_response(IntoResponse::into_response) .map_response(IntoResponse::into_response)
.service(ready_inner); .service(ready_inner);
let next = Next { inner }; let next = Next { inner };
f($($ty),*, next).await.into_response() f($($ty,)* $last, next).await.into_response()
}); });
ResponseFuture { ResponseFuture {
@ -305,7 +316,31 @@ macro_rules! impl_service {
}; };
} }
all_the_tuples!(impl_service); impl_service!([], T1);
impl_service!([T1], T2);
impl_service!([T1, T2], T3);
impl_service!([T1, T2, T3], T4);
impl_service!([T1, T2, T3, T4], T5);
impl_service!([T1, T2, T3, T4, T5], T6);
impl_service!([T1, T2, T3, T4, T5, T6], T7);
impl_service!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
impl<F, S, T> fmt::Debug for FromFn<F, S, T> impl<F, S, T> fmt::Debug for FromFn<F, S, T>
where where

View file

@ -1,7 +1,8 @@
use crate::extract::{FromRequest, RequestParts}; use crate::extract::FromRequestParts;
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts};
use headers::HeaderMapExt; use headers::HeaderMapExt;
use http::request::Parts;
use std::{convert::Infallible, ops::Deref}; use std::{convert::Infallible, ops::Deref};
/// Extractor and response that works with typed header values from [`headers`]. /// Extractor and response that works with typed header values from [`headers`].
@ -52,16 +53,15 @@ use std::{convert::Infallible, ops::Deref};
pub struct TypedHeader<T>(pub T); pub struct TypedHeader<T>(pub T);
#[async_trait] #[async_trait]
impl<T, S, B> FromRequest<S, B> for TypedHeader<T> impl<T, S> FromRequestParts<S> for TypedHeader<T>
where where
T: headers::Header, T: headers::Header,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = TypedHeaderRejection; type Rejection = TypedHeaderRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match req.headers().typed_try_get::<T>() { match parts.headers.typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)), Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection { Ok(None) => Err(TypedHeaderRejection {
name: T::name(), name: T::name(),

View file

@ -7,7 +7,7 @@
use axum::{ use axum::{
async_trait, async_trait,
body::{self, BoxBody, Bytes, Full}, body::{self, BoxBody, Bytes, Full},
extract::{FromRequest, RequestParts}, extract::FromRequest,
http::{Request, StatusCode}, http::{Request, StatusCode},
middleware::{self, Next}, middleware::{self, Next},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
@ -72,31 +72,28 @@ fn do_thing_with_request_body(bytes: Bytes) {
tracing::debug!(body = ?bytes); tracing::debug!(body = ?bytes);
} }
async fn handler(_: PrintRequestBody, body: Bytes) { async fn handler(BufferRequestBody(body): BufferRequestBody) {
tracing::debug!(?body, "handler received body"); tracing::debug!(?body, "handler received body");
} }
// extractor that shows how to consume the request body upfront // extractor that shows how to consume the request body upfront
struct PrintRequestBody; struct BufferRequestBody(Bytes);
// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
#[async_trait] #[async_trait]
impl<S> FromRequest<S, BoxBody> for PrintRequestBody impl<S> FromRequest<S, BoxBody> for BufferRequestBody
where where
S: Clone + Send + Sync, S: Send + Sync,
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, BoxBody>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let state = req.state().clone(); let body = Bytes::from_request(req, state)
let request = Request::from_request(req)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
let request = buffer_request_body(request).await?; do_thing_with_request_body(body.clone());
*req = RequestParts::with_state(state, request); Ok(Self(body))
Ok(Self)
} }
} }

View file

@ -4,15 +4,13 @@
//! and `async/await`. This means that you can create more powerful rejections //! and `async/await`. This means that you can create more powerful rejections
//! - Boilerplate: Requires creating a new extractor for every custom rejection //! - Boilerplate: Requires creating a new extractor for every custom rejection
//! - Complexity: Manually implementing `FromRequest` results on more complex code //! - Complexity: Manually implementing `FromRequest` results on more complex code
use axum::extract::MatchedPath;
use axum::{ use axum::{
async_trait, async_trait,
extract::{rejection::JsonRejection, FromRequest, RequestParts}, extract::{rejection::JsonRejection, FromRequest, FromRequestParts, MatchedPath},
http::Request,
http::StatusCode, http::StatusCode,
response::IntoResponse, response::IntoResponse,
BoxError,
}; };
use serde::de::DeserializeOwned;
use serde_json::{json, Value}; use serde_json::{json, Value};
pub async fn handler(Json(value): Json<Value>) -> impl IntoResponse { pub async fn handler(Json(value): Json<Value>) -> impl IntoResponse {
@ -25,31 +23,33 @@ pub struct Json<T>(pub T);
#[async_trait] #[async_trait]
impl<S, B, T> FromRequest<S, B> for Json<T> impl<S, B, T> FromRequest<S, B> for Json<T>
where where
axum::Json<T>: FromRequest<S, B, Rejection = JsonRejection>,
S: Send + Sync, S: Send + Sync,
// these trait bounds are copied from `impl FromRequest for axum::Json` B: Send + 'static,
// `T: Send` is required to send this future across an await
T: DeserializeOwned + Send,
B: axum::body::HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
{ {
type Rejection = (StatusCode, axum::Json<Value>); type Rejection = (StatusCode, axum::Json<Value>);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
match axum::Json::<T>::from_request(req).await { let (mut parts, body) = req.into_parts();
Ok(value) => Ok(Self(value.0)),
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
let path = req
.extract::<MatchedPath>()
.await
.map(|x| x.as_str().to_owned())
.ok();
// We can use other extractors to provide better rejection // We can use other extractors to provide better rejection
// messages. For example, here we are using // messages. For example, here we are using
// `axum::extract::MatchedPath` to provide a better error // `axum::extract::MatchedPath` to provide a better error
// message // message
//
// Have to run that first since `Json::from_request` consumes
// the request
let path = MatchedPath::from_request_parts(&mut parts, state)
.await
.map(|path| path.as_str().to_owned())
.ok();
let req = Request::from_parts(parts, body);
match axum::Json::<T>::from_request(req, state).await {
Ok(value) => Ok(Self(value.0)),
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
let payload = json!({ let payload = json!({
"message": rejection.to_string(), "message": rejection.to_string(),
"origin": "custom_extractor", "origin": "custom_extractor",

View file

@ -47,7 +47,7 @@ impl From<JsonRejection> for ApiError {
} }
} }
// We implement `IntoResponse` so ApiError can be used as a response // We implement `IntoResponse` so `ApiError` can be used as a response
impl IntoResponse for ApiError { impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
let payload = json!({ let payload = json!({

View file

@ -6,8 +6,8 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts}, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts},
http::StatusCode, http::{request::Parts, StatusCode},
response::IntoResponse, response::IntoResponse,
routing::get, routing::get,
Router, Router,
@ -52,17 +52,16 @@ struct Params {
struct Path<T>(T); struct Path<T>(T);
#[async_trait] #[async_trait]
impl<S, B, T> FromRequest<S, B> for Path<T> impl<S, T> FromRequestParts<S> for Path<T>
where where
// these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path`
T: DeserializeOwned + Send, T: DeserializeOwned + Send,
B: Send,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = (StatusCode, axum::Json<PathError>); type Rejection = (StatusCode, axum::Json<PathError>);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match axum::extract::Path::<T>::from_request(req).await { match axum::extract::Path::<T>::from_request_parts(parts, state).await {
Ok(value) => Ok(Self(value.0)), Ok(value) => Ok(Self(value.0)),
Err(rejection) => { Err(rejection) => {
let (status, body) = match rejection { let (status, body) = match rejection {

View file

@ -65,8 +65,8 @@ async fn users_show(
/// Handler for `POST /users`. /// Handler for `POST /users`.
async fn users_create( async fn users_create(
Json(params): Json<CreateUser>,
State(user_repo): State<DynUserRepo>, State(user_repo): State<DynUserRepo>,
Json(params): Json<CreateUser>,
) -> Result<Json<User>, AppError> { ) -> Result<Json<User>, AppError> {
let user = user_repo.create(params).await?; let user = user_repo.create(params).await?;

View file

@ -8,9 +8,9 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts, TypedHeader}, extract::{FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization}, headers::{authorization::Bearer, Authorization},
http::StatusCode, http::{request::Parts, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
@ -122,17 +122,16 @@ impl AuthBody {
} }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Claims impl<S> FromRequestParts<S> for Claims
where where
S: Send + Sync, S: Send + Sync,
B: Send,
{ {
type Rejection = AuthError; type Rejection = AuthError;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header // Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request(req) TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await .await
.map_err(|_| AuthError::InvalidToken)?; .map_err(|_| AuthError::InvalidToken)?;
// Decode the user data // Decode the user data

View file

@ -96,8 +96,8 @@ async fn kv_get(
async fn kv_set( async fn kv_set(
Path(key): Path<String>, Path(key): Path<String>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
State(state): State<SharedState>, State(state): State<SharedState>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
) { ) {
state.write().unwrap().db.insert(key, bytes); state.write().unwrap().db.insert(key, bytes);
} }

View file

@ -12,15 +12,14 @@ use async_session::{MemoryStore, Session, SessionStore};
use axum::{ use axum::{
async_trait, async_trait,
extract::{ extract::{
rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, rejection::TypedHeaderRejectionReason, FromRef, FromRequestParts, Query, State, TypedHeader,
TypedHeader,
}, },
http::{header::SET_COOKIE, HeaderMap}, http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response}, response::{IntoResponse, Redirect, Response},
routing::get, routing::get,
Router, Router,
}; };
use http::header; use http::{header, request::Parts};
use oauth2::{ use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
@ -139,7 +138,7 @@ async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
.url(); .url();
// Redirect to Discord's oauth service // Redirect to Discord's oauth service
Redirect::to(&auth_url.to_string()) Redirect::to(auth_url.as_ref())
} }
// Valid user session required. If there is none, redirect to the auth page // Valid user session required. If there is none, redirect to the auth page
@ -224,17 +223,18 @@ impl IntoResponse for AuthRedirect {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<AppState, B> for User impl<S> FromRequestParts<S> for User
where where
B: Send, MemoryStore: FromRef<S>,
S: Send + Sync,
{ {
// If anything goes wrong or no session is found, redirect to the auth page // If anything goes wrong or no session is found, redirect to the auth page
type Rejection = AuthRedirect; type Rejection = AuthRedirect;
async fn from_request(req: &mut RequestParts<AppState, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = req.state().clone().store; let store = MemoryStore::from_ref(state);
let cookies = TypedHeader::<headers::Cookie>::from_request(req) let cookies = TypedHeader::<headers::Cookie>::from_request_parts(parts, state)
.await .await
.map_err(|e| match *e.name() { .map_err(|e| match *e.name() {
header::COOKIE => match e.reason() { header::COOKIE => match e.reason() {

View file

@ -7,11 +7,12 @@
use async_session::{MemoryStore, Session, SessionStore as _}; use async_session::{MemoryStore, Session, SessionStore as _};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts, TypedHeader}, extract::{FromRef, FromRequestParts, TypedHeader},
headers::Cookie, headers::Cookie,
http::{ http::{
self, self,
header::{HeaderMap, HeaderValue}, header::{HeaderMap, HeaderValue},
request::Parts,
StatusCode, StatusCode,
}, },
response::IntoResponse, response::IntoResponse,
@ -80,16 +81,19 @@ enum UserIdFromSession {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<MemoryStore, B> for UserIdFromSession impl<S> FromRequestParts<S> for UserIdFromSession
where where
B: Send, MemoryStore: FromRef<S>,
S: Send + Sync,
{ {
type Rejection = (StatusCode, &'static str); type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<MemoryStore, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = req.state().clone(); let store = MemoryStore::from_ref(state);
let cookie = req.extract::<Option<TypedHeader<Cookie>>>().await.unwrap(); let cookie = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state)
.await
.unwrap();
let session_cookie = cookie let session_cookie = cookie
.as_ref() .as_ref()

View file

@ -15,8 +15,8 @@
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts, State}, extract::{FromRef, FromRequestParts, State},
http::StatusCode, http::{request::Parts, StatusCode},
routing::get, routing::get,
Router, Router,
}; };
@ -75,14 +75,15 @@ async fn using_connection_pool_extractor(
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>); struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
#[async_trait] #[async_trait]
impl<B> FromRequest<PgPool, B> for DatabaseConnection impl<S> FromRequestParts<S> for DatabaseConnection
where where
B: Send, PgPool: FromRef<S>,
S: Send + Sync,
{ {
type Rejection = (StatusCode, String); type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<PgPool, B>) -> Result<Self, Self::Rejection> { async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let pool = req.state().clone(); let pool = PgPool::from_ref(state);
let conn = pool.acquire().await.map_err(internal_error)?; let conn = pool.acquire().await.map_err(internal_error)?;

Some files were not shown because too many files have changed in this diff Show more