Only allow last extractor to mutate the request (#1272)

* Only allow last extractor to mutate the request

* Change `FromRequest` and add `FromRequestParts` trait (#1275)

* Add `Once`/`Mut` type parameter for `FromRequest` and `RequestParts`

* 🪄

* split traits

* `FromRequest` for tuples

* Remove `BodyAlreadyExtracted`

* don't need fully qualified path

* don't export `Once` and `Mut`

* remove temp tests

* depend on axum again

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* Port `Handler` and most extractors (#1277)

* Port `Handler` and most extractors

* Put `M` inside `Handler` impls, not trait itself

* comment out tuples for now

* fix lints

* Reorder arguments to `Handler` (#1281)

I think `Request<B>, Arc<S>` is better since its consistent with
`FromRequest` and `FromRequestParts`.

* Port most things in axum-extra (#1282)

* Port `#[derive(TypedPath)]` and `#[debug_handler]` (#1283)

* port #[derive(TypedPath)]

* wip: #[debug_handler]

* fix #[debug_handler]

* don't need itertools

* also require `Send`

* update expected error

* support fully qualified `self`

* Implement FromRequest[Parts] for tuples (#1286)

* Port docs for axum and axum-core (#1285)

* Port axum-extra (#1287)

* Port axum-extra

* Update axum-core/Cargo.toml

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* remove `impl FromRequest for Either*`

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* New FromRequest[Parts] trait cleanup (#1288)

* Make private module truly private again

* Simplify tuple FromRequest implementation

* Port `#[derive(FromRequest)]` (#1289)

* fix tests

* fix docs

* revert examples

* fix docs link

* fix intra docs links

* Port examples (#1291)

* Document wrapping other extractors (#1292)

* axum-extra doesn't need to depend on axum-core (#1294)

Missed this in https://github.com/tokio-rs/axum/pull/1287

* Add `FromRequest` changes to changelogs (#1293)

* Update changelog

* Remove default type for `S` in `Handler`

* Clarify which types have default types for `S`

* Apply suggestions from code review

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* remove unused import

* Rename `Mut` and `Once` (#1296)

* fix trybuild expected output

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-08-22 12:23:20 +02:00 committed by GitHub
parent f1769e5134
commit be624306f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
104 changed files with 1513 additions and 1936 deletions

View file

@ -7,10 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which
represents the state ([#1155])
- **breaking:** `FromRequest` has been reworked and `RequestParts` has been
removed. See axum's changelog for more details ([#1272])
- **added:** Added new `FromRequestParts` trait. See axum's changelog for more
details ([#1272])
- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272])
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1272]: https://github.com/tokio-rs/axum/pull/1272
# 0.2.7 (10. July, 2022)

View file

@ -4,11 +4,10 @@
//!
//! [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
use self::rejection::*;
use crate::response::IntoResponse;
use async_trait::async_trait;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::{convert::Infallible, sync::Arc};
use http::{request::Parts, Request};
use std::convert::Infallible;
pub mod rejection;
@ -18,9 +17,44 @@ mod tuple;
pub use self::from_ref::FromRef;
mod private {
#[derive(Debug, Clone, Copy)]
pub enum ViaParts {}
#[derive(Debug, Clone, Copy)]
pub enum ViaRequest {}
}
/// Types that can be created from request parts.
///
/// Extractors that implement `FromRequestParts` cannot consume the request body and can thus be
/// run in any order for handlers.
///
/// If your extractor needs to consume the request body then you should implement [`FromRequest`]
/// and not [`FromRequestParts`].
///
/// See [`axum::extract`] for more general docs about extraxtors.
///
/// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html
#[async_trait]
pub trait FromRequestParts<S>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection>;
}
/// Types that can be created from requests.
///
/// See [`axum::extract`] for more details.
/// Extractors that implement `FromRequest` can consume the request body and can thus only be run
/// once for handlers.
///
/// If your extractor doesn't need to consume the request body then you should implement
/// [`FromRequestParts`] and not [`FromRequest`].
///
/// See [`axum::extract`] for more general docs about extraxtors.
///
/// # What is the `B` type parameter?
///
@ -39,7 +73,8 @@ pub use self::from_ref::FromRef;
/// ```rust
/// use axum::{
/// async_trait,
/// extract::{FromRequest, RequestParts},
/// extract::FromRequest,
/// http::Request,
/// };
///
/// struct MyExtractor;
@ -48,12 +83,12 @@ pub use self::from_ref::FromRef;
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// // these bounds are required by `async_trait`
/// B: Send,
/// B: Send + 'static,
/// S: Send + Sync,
/// {
/// type Rejection = http::StatusCode;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
@ -63,231 +98,45 @@ pub use self::from_ref::FromRef;
/// This ensures your extractor is as flexible as possible.
///
/// [`http::Request<B>`]: http::Request
/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
/// [`axum::extract`]: https://docs.rs/axum/0.6/axum/extract/index.html
#[async_trait]
pub trait FromRequest<S, B>: Sized {
pub trait FromRequest<S, B, M = private::ViaRequest>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection>;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection>;
}
/// The type used with [`FromRequest`] to extract data from requests.
///
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<S, B> {
pub(crate) state: Arc<S>,
method: Method,
uri: Uri,
version: Version,
headers: HeaderMap,
extensions: Extensions,
body: Option<B>,
}
impl<B> RequestParts<(), B> {
/// Create a new `RequestParts` without any state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn new(req: Request<B>) -> Self {
Self::with_state((), req)
}
}
impl<S, B> RequestParts<S, B> {
/// Create a new `RequestParts` with the given state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
Self::with_state_arc(Arc::new(state), req)
}
/// Create a new `RequestParts` with the given [`Arc`]'ed state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state_arc(state: Arc<S>, req: Request<B>) -> Self {
let (
http::request::Parts {
method,
uri,
version,
headers,
extensions,
..
},
body,
) = req.into_parts();
RequestParts {
state,
method,
uri,
version,
headers,
extensions,
body: Some(body),
}
}
/// Apply an extractor to this `RequestParts`.
///
/// `req.extract::<Extractor>()` is equivalent to `Extractor::from_request(req)`.
/// This function simply exists as a convenience.
///
/// # Example
///
/// ```
/// # struct MyExtractor {}
///
/// use std::convert::Infallible;
///
/// use async_trait::async_trait;
/// use axum::extract::{FromRequest, RequestParts};
/// use http::{Method, Uri};
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Infallible> {
/// let method = req.extract::<Method>().await?;
/// let path = req.extract::<Uri>().await?.path().to_owned();
///
/// todo!()
/// }
/// }
/// ```
pub async fn extract<E>(&mut self) -> Result<E, E::Rejection>
#[async_trait]
impl<S, B, T> FromRequest<S, B, private::ViaParts> for T
where
E: FromRequest<S, B>,
B: Send + 'static,
S: Send + Sync,
T: FromRequestParts<S>,
{
E::from_request(self).await
type Rejection = <Self as FromRequestParts<S>>::Rejection;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, _) = req.into_parts();
Self::from_request_parts(&mut parts, state).await
}
}
/// Convert this `RequestParts` back into a [`Request`].
///
/// Fails if The request body has been extracted, that is [`take_body`] has
/// been called.
///
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> {
let Self {
state: _,
method,
uri,
version,
headers,
extensions,
mut body,
} = self;
#[async_trait]
impl<S, T> FromRequestParts<S> for Option<T>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;
let mut req = if let Some(body) = body.take() {
Request::new(body)
} else {
return Err(BodyAlreadyExtracted);
};
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;
*req.headers_mut() = headers;
*req.extensions_mut() = extensions;
Ok(req)
}
/// Gets a reference to the request method.
pub fn method(&self) -> &Method {
&self.method
}
/// Gets a mutable reference to the request method.
pub fn method_mut(&mut self) -> &mut Method {
&mut self.method
}
/// Gets a reference to the request URI.
pub fn uri(&self) -> &Uri {
&self.uri
}
/// Gets a mutable reference to the request URI.
pub fn uri_mut(&mut self) -> &mut Uri {
&mut self.uri
}
/// Get the request HTTP version.
pub fn version(&self) -> Version {
self.version
}
/// Gets a mutable reference to the request HTTP version.
pub fn version_mut(&mut self) -> &mut Version {
&mut self.version
}
/// Gets a reference to the request headers.
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
/// Gets a mutable reference to the request headers.
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
/// Gets a reference to the request extensions.
pub fn extensions(&self) -> &Extensions {
&self.extensions
}
/// Gets a mutable reference to the request extensions.
pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
/// Gets a reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
pub fn body(&self) -> Option<&B> {
self.body.as_ref()
}
/// Gets a mutable reference to the request body.
///
/// Returns `None` if the body has been taken by another extractor.
// this returns `&mut Option<B>` rather than `Option<&mut B>` such that users can use it to set the body.
pub fn body_mut(&mut self) -> &mut Option<B> {
&mut self.body
}
/// Takes the body out of the request, leaving a `None` in its place.
pub fn take_body(&mut self) -> Option<B> {
self.body.take()
}
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await.ok())
}
}
@ -295,13 +144,26 @@ impl<S, B> RequestParts<S, B> {
impl<S, T, B> FromRequest<S, B> for Option<T>
where
T: FromRequest<S, B>,
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok())
async fn from_request(req: Request<B>, state: &S) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req, state).await.ok())
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await)
}
}
@ -309,12 +171,12 @@ where
impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where
T: FromRequest<S, B>,
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await)
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req, state).await)
}
}

View file

@ -1,35 +1,6 @@
//! Rejection response types.
use crate::{
response::{IntoResponse, Response},
BoxError,
};
use http::StatusCode;
use std::fmt;
/// Rejection type used if you try and extract the request body more than
/// once.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct BodyAlreadyExtracted;
impl BodyAlreadyExtracted {
const BODY: &'static str = "Cannot have two request body extractors for a single handler";
}
impl IntoResponse for BodyAlreadyExtracted {
fn into_response(self) -> Response {
(StatusCode::INTERNAL_SERVER_ERROR, Self::BODY).into_response()
}
}
impl fmt::Display for BodyAlreadyExtracted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", Self::BODY)
}
}
impl std::error::Error for BodyAlreadyExtracted {}
use crate::BoxError;
composite_rejection! {
/// Rejection type for extractors that buffer the request body. Used if the
@ -85,7 +56,6 @@ composite_rejection! {
/// Contains one variant for each way the [`Bytes`](bytes::Bytes) extractor
/// can fail.
pub enum BytesRejection {
BodyAlreadyExtracted,
FailedToBufferBody,
}
}
@ -95,7 +65,6 @@ composite_rejection! {
///
/// Contains one variant for each way the [`String`] extractor can fail.
pub enum StringRejection {
BodyAlreadyExtracted,
FailedToBufferBody,
InvalidUtf8,
}

View file

@ -1,9 +1,9 @@
use super::{rejection::*, FromRequest, RequestParts};
use super::{rejection::*, FromRequest, FromRequestParts};
use crate::BoxError;
use async_trait::async_trait;
use bytes::Bytes;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::{convert::Infallible, sync::Arc};
use http::{request::Parts, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible;
#[async_trait]
impl<S, B> FromRequest<S, B> for Request<B>
@ -11,62 +11,46 @@ where
B: Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let req = std::mem::replace(
req,
RequestParts {
state: Arc::clone(&req.state),
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: HeaderMap::new(),
extensions: Extensions::default(),
body: None,
},
);
req.try_into_request()
async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
Ok(req)
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Method
impl<S> FromRequestParts<S> for Method
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.method().clone())
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(parts.method.clone())
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Uri
impl<S> FromRequestParts<S> for Uri
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone())
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(parts.uri.clone())
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Version
impl<S> FromRequestParts<S> for Version
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.version())
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(parts.version)
}
}
@ -76,30 +60,29 @@ where
///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait]
impl<S, B> FromRequest<S, B> for HeaderMap
impl<S> FromRequestParts<S> for HeaderMap
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.headers().clone())
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
Ok(parts.headers.clone())
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Bytes
where
B: http_body::Body + Send,
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = BytesRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let body = req.into_body();
let bytes = crate::body::to_bytes(body)
.await
@ -112,15 +95,15 @@ where
#[async_trait]
impl<S, B> FromRequest<S, B> for String
where
B: http_body::Body + Send,
B: http_body::Body + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = StringRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
let body = req.into_body();
let bytes = crate::body::to_bytes(body)
.await
@ -134,40 +117,14 @@ where
}
#[async_trait]
impl<S, B> FromRequest<S, B> for http::request::Parts
impl<S, B> FromRequest<S, B> for Parts
where
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = unwrap_infallible(HeaderMap::from_request(req).await);
let extensions = std::mem::take(req.extensions_mut());
let mut temp_request = Request::new(());
*temp_request.method_mut() = method;
*temp_request.uri_mut() = uri;
*temp_request.version_mut() = version;
*temp_request.headers_mut() = headers;
*temp_request.extensions_mut() = extensions;
let (parts, _) = temp_request.into_parts();
Ok(parts)
async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
Ok(req.into_parts().0)
}
}
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted)
}

View file

@ -1,41 +1,143 @@
use super::{FromRequest, RequestParts};
use super::{FromRequest, FromRequestParts};
use crate::response::{IntoResponse, Response};
use async_trait::async_trait;
use http::request::{Parts, Request};
use std::convert::Infallible;
#[async_trait]
impl<S, B> FromRequest<S, B> for ()
impl<S> FromRequestParts<S> for ()
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<S, B>) -> Result<(), Self::Rejection> {
async fn from_request_parts(_: &mut Parts, _: &S) -> Result<(), Self::Rejection> {
Ok(())
}
}
macro_rules! impl_from_request {
() => {};
( $($ty:ident),* $(,)? ) => {
(
[$($ty:ident),*], $last:ident
) => {
#[async_trait]
#[allow(non_snake_case)]
impl<S, B, $($ty,)*> FromRequest<S, B> for ($($ty,)*)
#[allow(non_snake_case, unused_mut, unused_variables)]
impl<S, $($ty,)* $last> FromRequestParts<S> for ($($ty,)* $last,)
where
$( $ty: FromRequest<S, B> + Send, )*
B: Send,
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequestParts<S> + Send,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )*
Ok(($($ty,)*))
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
$(
let $ty = $ty::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
)*
let $last = $last::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
Ok(($($ty,)* $last,))
}
}
// This impl must not be generic over M, otherwise it would conflict with the blanket
// implementation of `FromRequest<S, B, Mut>` for `T: FromRequestParts<S>`.
#[async_trait]
#[allow(non_snake_case, unused_mut, unused_variables)]
impl<S, B, $($ty,)* $last> FromRequest<S, B> for ($($ty,)* $last,)
where
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B> + Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, body) = req.into_parts();
$(
let $ty = $ty::from_request_parts(&mut parts, state).await.map_err(|err| err.into_response())?;
)*
let req = Request::from_parts(parts, body);
let $last = $last::from_request(req, state).await.map_err(|err| err.into_response())?;
Ok(($($ty,)* $last,))
}
}
};
}
all_the_tuples!(impl_from_request);
impl_from_request!([], T1);
impl_from_request!([T1], T2);
impl_from_request!([T1, T2], T3);
impl_from_request!([T1, T2, T3], T4);
impl_from_request!([T1, T2, T3, T4], T5);
impl_from_request!([T1, T2, T3, T4, T5], T6);
impl_from_request!([T1, T2, T3, T4, T5, T6], T7);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_from_request!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_from_request!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http::Method;
use crate::extract::{FromRequest, FromRequestParts};
fn assert_from_request<M, T>()
where
T: FromRequest<(), http_body::Full<Bytes>, M>,
{
}
fn assert_from_request_parts<T: FromRequestParts<()>>() {}
#[test]
fn unit() {
assert_from_request_parts::<()>();
assert_from_request::<_, ()>();
}
#[test]
fn tuple_of_one() {
assert_from_request_parts::<(Method,)>();
assert_from_request::<_, (Method,)>();
assert_from_request::<_, (Bytes,)>();
}
#[test]
fn tuple_of_two() {
assert_from_request_parts::<((), ())>();
assert_from_request::<_, ((), ())>();
assert_from_request::<_, (Method, Bytes)>();
}
#[test]
fn nested_tuple() {
assert_from_request_parts::<(((Method,),),)>();
assert_from_request::<_, ((((Bytes,),),),)>();
}
}

View file

@ -4,15 +4,50 @@
//!
//! ```
//! use axum_extra::either::Either3;
//! use axum::{body::Bytes, Json};
//! use axum::{
//! body::Bytes,
//! Router,
//! async_trait,
//! routing::get,
//! extract::FromRequestParts,
//! };
//!
//! // extractors for checking permissions
//! struct AdminPermissions {}
//!
//! #[async_trait]
//! impl<S> FromRequestParts<S> for AdminPermissions
//! where
//! S: Send + Sync,
//! {
//! // check for admin permissions...
//! # type Rejection = ();
//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
//! # todo!()
//! # }
//! }
//!
//! struct User {}
//!
//! #[async_trait]
//! impl<S> FromRequestParts<S> for User
//! where
//! S: Send + Sync,
//! {
//! // check for a logged in user...
//! # type Rejection = ();
//! # async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
//! # todo!()
//! # }
//! }
//!
//! async fn handler(
//! body: Either3<Json<serde_json::Value>, String, Bytes>,
//! body: Either3<AdminPermissions, User, ()>,
//! ) {
//! match body {
//! Either3::E1(json) => { /* ... */ }
//! Either3::E2(string) => { /* ... */ }
//! Either3::E3(bytes) => { /* ... */ }
//! Either3::E1(admin) => { /* ... */ }
//! Either3::E2(user) => { /* ... */ }
//! Either3::E3(guest) => { /* ... */ }
//! }
//! }
//! #
@ -60,9 +95,10 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequestParts,
response::{IntoResponse, Response},
};
use http::request::Parts;
/// Combines two extractors or responses into a single type.
///
@ -190,23 +226,22 @@ macro_rules! impl_traits_for_either {
$last:ident $(,)?
) => {
#[async_trait]
impl<S, B, $($ident),*, $last> FromRequest<S, B> for $either<$($ident),*, $last>
impl<S, $($ident),*, $last> FromRequestParts<S> for $either<$($ident),*, $last>
where
$($ident: FromRequest<S, B>),*,
$last: FromRequest<S, B>,
B: Send,
$($ident: FromRequestParts<S>),*,
$last: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = $last::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
$(
if let Ok(value) = req.extract().await {
if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await {
return Ok(Self::$ident(value));
}
)*
req.extract().await.map(Self::$last)
FromRequestParts::from_request_parts(parts, state).await.map(Self::$last)
}
}

View file

@ -1,7 +1,8 @@
use axum::{
async_trait,
extract::{Extension, FromRequest, RequestParts},
extract::{Extension, FromRequest, FromRequestParts},
};
use http::{request::Parts, Request};
use std::ops::{Deref, DerefMut};
/// Cache results of other extractors.
@ -20,24 +21,23 @@ use std::ops::{Deref, DerefMut};
/// use axum_extra::extract::Cached;
/// use axum::{
/// async_trait,
/// extract::{FromRequest, RequestParts},
/// extract::FromRequestParts,
/// body::BoxBody,
/// response::{IntoResponse, Response},
/// http::StatusCode,
/// http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for Session
/// impl<S> FromRequestParts<S> for Session
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // load session...
/// # unimplemented!()
/// }
@ -46,19 +46,18 @@ use std::ops::{Deref, DerefMut};
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for CurrentUser
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Response;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // loading a `CurrentUser` requires first loading the `Session`
/// //
/// // by using `Cached<Session>` we avoid extracting the session more than
/// // once, in case other extractors for the same request also loads the session
/// let session: Session = Cached::<Session>::from_request(req)
/// let session: Session = Cached::<Session>::from_request_parts(parts, state)
/// .await
/// .map_err(|err| err.into_response())?
/// .0;
@ -92,18 +91,40 @@ struct CachedEntry<T>(T);
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Cached<T>
where
B: Send,
B: Send + 'static,
S: Send + Sync,
T: FromRequest<S, B> + Clone + Send + Sync + 'static,
T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await {
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, _) = req.into_parts();
match Extension::<CachedEntry<T>>::from_request_parts(&mut parts, state).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => {
let value = T::from_request(req).await?;
req.extensions_mut().insert(CachedEntry(value.clone()));
let value = T::from_request_parts(&mut parts, state).await?;
parts.extensions.insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Cached<T>
where
S: Send + Sync,
T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => {
let value = T::from_request_parts(parts, state).await?;
parts.extensions.insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
@ -127,7 +148,8 @@ impl<T> DerefMut for Cached<T> {
#[cfg(test)]
mod tests {
use super::*;
use axum::http::Request;
use axum::{extract::FromRequestParts, http::Request};
use http::request::Parts;
use std::{
convert::Infallible,
sync::atomic::{AtomicU32, Ordering},
@ -142,25 +164,33 @@ mod tests {
struct Extractor(Instant);
#[async_trait]
impl<S, B> FromRequest<S, B> for Extractor
impl<S> FromRequestParts<S> for Extractor
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now()))
}
}
let mut req = RequestParts::new(Request::new(()));
let (mut parts, _) = Request::new(()).into_parts();
let first = Cached::<Extractor>::from_request(&mut req).await.unwrap().0;
let first = Cached::<Extractor>::from_request_parts(&mut parts, &())
.await
.unwrap()
.0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
let second = Cached::<Extractor>::from_request(&mut req).await.unwrap().0;
let second = Cached::<Extractor>::from_request_parts(&mut parts, &())
.await
.unwrap()
.0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
assert_eq!(first, second);

View file

@ -4,11 +4,12 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequestParts,
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use http::{
header::{COOKIE, SET_COOKIE},
request::Parts,
HeaderMap,
};
use std::convert::Infallible;
@ -88,15 +89,14 @@ pub struct CookieJar {
}
#[async_trait]
impl<S, B> FromRequest<S, B> for CookieJar
impl<S> FromRequestParts<S> for CookieJar
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(Self::from_headers(req.headers()))
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(Self::from_headers(&parts.headers))
}
}
@ -115,7 +115,9 @@ impl CookieJar {
/// The cookies in `headers` will be added to the jar.
///
/// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `CookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap) -> Self {
let mut jar = cookie::CookieJar::new();
for cookie in cookies_from_request(headers) {
@ -127,10 +129,12 @@ impl CookieJar {
/// Create a new empty `CookieJar`.
///
/// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `CookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
///
/// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for
/// CookieJar`.
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new() -> Self {
Self::default()
}

View file

@ -1,11 +1,11 @@
use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{
async_trait,
extract::{FromRef, FromRequest, RequestParts},
extract::{FromRef, FromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::PrivateJar;
use http::HeaderMap;
use http::{request::Parts, HeaderMap};
use std::{convert::Infallible, fmt, marker::PhantomData};
/// Extractor that grabs private cookies from the request and manages the jar.
@ -87,22 +87,21 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
}
#[async_trait]
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
impl<S, K> FromRequestParts<S> for PrivateCookieJar<K>
where
B: Send,
S: Send + Sync,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state());
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let k = K::from_ref(state);
let key = k.into();
let PrivateCookieJar {
jar,
key,
_marker: _,
} = PrivateCookieJar::from_headers(req.headers(), key);
} = PrivateCookieJar::from_headers(&parts.headers, key);
Ok(PrivateCookieJar {
jar,
key,
@ -117,7 +116,9 @@ impl PrivateCookieJar {
/// The valid cookies in `headers` will be added to the jar.
///
/// This is inteded to be used in middleware and other where places it might be difficult to
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap, key: Key) -> Self {
let mut jar = cookie::CookieJar::new();
let mut private_jar = jar.private_mut(&key);
@ -137,7 +138,9 @@ impl PrivateCookieJar {
/// Create a new empty `PrivateCookieJarIter`.
///
/// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `PrivateCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new(key: Key) -> Self {
Self {
jar: Default::default(),

View file

@ -1,12 +1,12 @@
use super::{cookies_from_request, set_cookies};
use axum::{
async_trait,
extract::{FromRef, FromRequest, RequestParts},
extract::{FromRef, FromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use cookie::SignedJar;
use cookie::{Cookie, Key};
use http::HeaderMap;
use http::{request::Parts, HeaderMap};
use std::{convert::Infallible, fmt, marker::PhantomData};
/// Extractor that grabs signed cookies from the request and manages the jar.
@ -105,22 +105,21 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
}
#[async_trait]
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
impl<S, K> FromRequestParts<S> for SignedCookieJar<K>
where
B: Send,
S: Send + Sync,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state());
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let k = K::from_ref(state);
let key = k.into();
let SignedCookieJar {
jar,
key,
_marker: _,
} = SignedCookieJar::from_headers(req.headers(), key);
} = SignedCookieJar::from_headers(&parts.headers, key);
Ok(SignedCookieJar {
jar,
key,
@ -135,7 +134,9 @@ impl SignedCookieJar {
/// The valid cookies in `headers` will be added to the jar.
///
/// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn from_headers(headers: &HeaderMap, key: Key) -> Self {
let mut jar = cookie::CookieJar::new();
let mut signed_jar = jar.signed_mut(&key);
@ -155,7 +156,9 @@ impl SignedCookieJar {
/// Create a new empty `SignedCookieJar`.
///
/// This is inteded to be used in middleware and other places where it might be difficult to
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequest`].
/// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
///
/// [`FromRequestParts`]: axum::extract::FromRequestParts
pub fn new(key: Key) -> Self {
Self {
jar: Default::default(),

View file

@ -3,12 +3,12 @@ use axum::{
body::HttpBody,
extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, InvalidFormContentType},
FromRequest, RequestParts,
FromRequest,
},
BoxError,
};
use bytes::Bytes;
use http::{header, Method};
use http::{header, HeaderMap, Method, Request};
use serde::de::DeserializeOwned;
use std::ops::Deref;
@ -58,25 +58,25 @@ impl<T> Deref for Form<T> {
impl<T, S, B> FromRequest<S, B> for Form<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Form(value))
} else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) {
if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType::default().into());
}
let bytes = Bytes::from_request(req).await?;
let bytes = Bytes::from_request(req, state).await?;
let value = serde_html_form::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?;
@ -86,8 +86,8 @@ where
}
// this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type<S, B>(req: &RequestParts<S, B>, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type
} else {
return false;

View file

@ -2,9 +2,10 @@ use axum::{
async_trait,
extract::{
rejection::{FailedToDeserializeQueryString, QueryRejection},
FromRequest, RequestParts,
FromRequestParts,
},
};
use http::request::Parts;
use serde::de::DeserializeOwned;
use std::ops::Deref;
@ -58,16 +59,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T);
#[async_trait]
impl<T, S, B> FromRequest<S, B> for Query<T>
impl<T, S> FromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default();
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Query(value))

View file

@ -1,6 +1,8 @@
use axum::async_trait;
use axum::extract::{FromRequest, RequestParts};
use axum::extract::{FromRequest, FromRequestParts};
use axum::response::IntoResponse;
use http::request::Parts;
use http::Request;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
@ -109,23 +111,40 @@ impl<E, R> DerefMut for WithRejection<E, R> {
#[async_trait]
impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R>
where
B: Send,
B: Send + 'static,
S: Send + Sync,
E: FromRequest<S, B>,
R: From<E::Rejection> + IntoResponse,
{
type Rejection = R;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let extractor = req.extract::<E>().await?;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let extractor = E::from_request(req, state).await?;
Ok(WithRejection(extractor, PhantomData))
}
}
#[async_trait]
impl<E, R, S> FromRequestParts<S> for WithRejection<E, R>
where
S: Send + Sync,
E: FromRequestParts<S>,
R: From<E::Rejection> + IntoResponse,
{
type Rejection = R;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let extractor = E::from_request_parts(parts, state).await?;
Ok(WithRejection(extractor, PhantomData))
}
}
#[cfg(test)]
mod tests {
use axum::extract::FromRequestParts;
use axum::http::Request;
use axum::response::Response;
use http::request::Parts;
use super::*;
@ -135,14 +154,16 @@ mod tests {
struct TestRejection;
#[async_trait]
impl<S, B> FromRequest<S, B> for TestExtractor
impl<S> FromRequestParts<S> for TestExtractor
where
B: Send,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
Err(())
}
}
@ -159,12 +180,14 @@ mod tests {
}
}
let mut req = RequestParts::new(Request::new(()));
let req = Request::new(());
let result = WithRejection::<TestExtractor, TestRejection>::from_request(req, &()).await;
assert!(matches!(result, Err(TestRejection)));
let result = req
.extract::<WithRejection<TestExtractor, TestRejection>>()
let (mut parts, _) = Request::new(()).into_parts();
let result =
WithRejection::<TestExtractor, TestRejection>::from_request_parts(&mut parts, &())
.await;
assert!(matches!(result, Err(TestRejection)))
assert!(matches!(result, Err(TestRejection)));
}
}

View file

@ -1,7 +1,7 @@
//! Additional handler utilities.
use axum::{
extract::{FromRequest, RequestParts},
extract::FromRequest,
handler::Handler,
response::{IntoResponse, Response},
};
@ -26,8 +26,8 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// Call the handler with the extracted inputs.
fn call(
self,
state: Arc<S>,
extractors: T,
state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`].
@ -51,7 +51,7 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// Router,
/// async_trait,
/// routing::get,
/// extract::FromRequest,
/// extract::FromRequestParts,
/// };
///
/// // handlers for varying levels of access
@ -71,14 +71,13 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// struct AdminPermissions {}
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for AdminPermissions
/// impl<S> FromRequestParts<S> for AdminPermissions
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// // check for admin permissions...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
@ -86,14 +85,13 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// struct User {}
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for User
/// impl<S> FromRequestParts<S> for User
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// // check for a logged in user...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # async fn from_request_parts(parts: &mut http::request::Parts, state: &S) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
@ -134,8 +132,8 @@ macro_rules! impl_handler_call_with {
fn call(
self,
_state: Arc<S>,
($($ty,)*): ($($ty,)*),
_state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
}
@ -180,11 +178,10 @@ where
{
type Future = BoxFuture<'static, Response>;
fn call(self, state: Arc<S>, req: http::Request<B>) -> Self::Future {
fn call(self, req: http::Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req);
match req.extract::<T>().await {
Ok(t) => self.handler.call(state, t).await,
match T::from_request(req, &state).await {
Ok(t) => self.handler.call(t, state).await,
Err(rejection) => rejection.into_response(),
}
})

View file

@ -1,13 +1,12 @@
use super::HandlerCallWithExtractors;
use crate::either::Either;
use axum::{
extract::{FromRequest, RequestParts},
extract::{FromRequest, FromRequestParts},
handler::Handler,
http::Request,
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use http::StatusCode;
use std::{future::Future, marker::PhantomData, sync::Arc};
/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
@ -37,30 +36,30 @@ where
fn call(
self,
state: Arc<S>,
extractors: Either<Lt, Rt>,
state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors {
Either::E1(lt) => self
.lhs
.call(state, lt)
.call(lt, state)
.map(IntoResponse::into_response as _)
.left_future(),
Either::E2(rt) => self
.rhs
.call(state, rt)
.call(rt, state)
.map(IntoResponse::into_response as _)
.right_future(),
}
}
}
impl<S, B, L, R, Lt, Rt> Handler<(Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B>
impl<S, B, L, R, Lt, Rt, M> Handler<(M, Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B>
where
L: HandlerCallWithExtractors<Lt, S, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, S, B> + Clone + Send + 'static,
Lt: FromRequest<S, B> + Send + 'static,
Rt: FromRequest<S, B> + Send + 'static,
Lt: FromRequestParts<S> + Send + 'static,
Rt: FromRequest<S, B, M> + Send + 'static,
Lt::Rejection: Send,
Rt::Rejection: Send,
B: Send + 'static,
@ -69,19 +68,20 @@ where
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req);
let (mut parts, body) = req.into_parts();
if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(state, lt).await;
if let Ok(lt) = Lt::from_request_parts(&mut parts, &state).await {
return self.lhs.call(lt, state).await;
}
if let Ok(rt) = req.extract::<Rt>().await {
return self.rhs.call(state, rt).await;
}
let req = Request::from_parts(parts, body);
StatusCode::NOT_FOUND.into_response()
match Rt::from_request(req, &state).await {
Ok(rt) => self.rhs.call(rt, state).await,
Err(rejection) => rejection.into_response(),
}
})
}
}

View file

@ -3,15 +3,17 @@
use axum::{
async_trait,
body::{HttpBody, StreamBody},
extract::{rejection::BodyAlreadyExtracted, FromRequest, RequestParts},
extract::FromRequest,
response::{IntoResponse, Response},
BoxError,
};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt};
use http::Request;
use pin_project_lite::pin_project;
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert::Infallible,
io::{self, Write},
marker::PhantomData,
pin::Pin,
@ -106,14 +108,14 @@ where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
// `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
// so we can call `AsyncRead::lines` and then convert it back to a `Stream`
let body = req.take_body().ok_or_else(BodyAlreadyExtracted::default)?;
let body = BodyStream { body };
let body = BodyStream {
body: req.into_body(),
};
let stream = body
.map_ok(Into::into)

View file

@ -3,12 +3,12 @@
use axum::{
async_trait,
body::{Bytes, HttpBody},
extract::{rejection::BytesRejection, FromRequest, RequestParts},
extract::{rejection::BytesRejection, FromRequest},
response::{IntoResponse, Response},
BoxError,
};
use bytes::BytesMut;
use http::StatusCode;
use http::{Request, StatusCode};
use prost::Message;
use std::ops::{Deref, DerefMut};
@ -100,15 +100,15 @@ pub struct ProtoBuf<T>(pub T);
impl<T, S, B> FromRequest<S, B> for ProtoBuf<T>
where
T: Message + Default,
B: HttpBody + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = ProtoBufRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let mut bytes = Bytes::from_request(req).await?;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let mut bytes = Bytes::from_request(req, state).await?;
match T::decode(&mut bytes) {
Ok(value) => Ok(ProtoBuf(value)),

View file

@ -24,7 +24,7 @@ pub use self::resource::Resource;
pub use axum_macros::TypedPath;
#[cfg(feature = "typed-routing")]
pub use self::typed::{FirstElementIs, TypedPath};
pub use self::typed::{SecondElementIs, TypedPath};
#[cfg(feature = "spa")]
pub use self::spa::SpaRouter;
@ -41,7 +41,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `DELETE` route to the router.
@ -54,7 +54,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `HEAD` route to the router.
@ -67,7 +67,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `OPTIONS` route to the router.
@ -80,7 +80,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `PATCH` route to the router.
@ -93,7 +93,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `POST` route to the router.
@ -106,7 +106,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `PUT` route to the router.
@ -119,7 +119,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `TRACE` route to the router.
@ -132,7 +132,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath;
/// Add another route to the router with an additional "trailing slash redirect" route.
@ -184,7 +184,7 @@ where
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::get(handler))
@ -194,7 +194,7 @@ where
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::delete(handler))
@ -204,7 +204,7 @@ where
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::head(handler))
@ -214,7 +214,7 @@ where
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::options(handler))
@ -224,7 +224,7 @@ where
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::patch(handler))
@ -234,7 +234,7 @@ where
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::post(handler))
@ -244,7 +244,7 @@ where
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::put(handler))
@ -254,7 +254,7 @@ where
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
T: SecondElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::trace(handler))

View file

@ -231,10 +231,10 @@ pub trait TypedPath: std::fmt::Display {
}
}
/// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a
/// Utility trait used with [`RouterExt`] to ensure the second element of a tuple type is a
/// given type.
///
/// If you see it in type errors its most likely because the first argument to your handler doesn't
/// If you see it in type errors its most likely because the second argument to your handler doesn't
/// implement [`TypedPath`].
///
/// You normally shouldn't have to use this trait directly.
@ -242,56 +242,56 @@ pub trait TypedPath: std::fmt::Display {
/// It is sealed such that it cannot be implemented outside this crate.
///
/// [`RouterExt`]: super::RouterExt
pub trait FirstElementIs<P>: Sealed {}
pub trait SecondElementIs<P>: Sealed {}
macro_rules! impl_first_element_is {
macro_rules! impl_second_element_is {
( $($ty:ident),* $(,)? ) => {
impl<P, $($ty,)*> FirstElementIs<P> for (P, $($ty,)*)
impl<M, P, $($ty,)*> SecondElementIs<P> for (M, P, $($ty,)*)
where
P: TypedPath
{}
impl<P, $($ty,)*> Sealed for (P, $($ty,)*)
impl<M, P, $($ty,)*> Sealed for (M, P, $($ty,)*)
where
P: TypedPath
{}
impl<P, $($ty,)*> FirstElementIs<P> for (Option<P>, $($ty,)*)
impl<M, P, $($ty,)*> SecondElementIs<P> for (M, Option<P>, $($ty,)*)
where
P: TypedPath
{}
impl<P, $($ty,)*> Sealed for (Option<P>, $($ty,)*)
impl<M, P, $($ty,)*> Sealed for (M, Option<P>, $($ty,)*)
where
P: TypedPath
{}
impl<P, E, $($ty,)*> FirstElementIs<P> for (Result<P, E>, $($ty,)*)
impl<M, P, E, $($ty,)*> SecondElementIs<P> for (M, Result<P, E>, $($ty,)*)
where
P: TypedPath
{}
impl<P, E, $($ty,)*> Sealed for (Result<P, E>, $($ty,)*)
impl<M, P, E, $($ty,)*> Sealed for (M, Result<P, E>, $($ty,)*)
where
P: TypedPath
{}
};
}
impl_first_element_is!();
impl_first_element_is!(T1);
impl_first_element_is!(T1, T2);
impl_first_element_is!(T1, T2, T3);
impl_first_element_is!(T1, T2, T3, T4);
impl_first_element_is!(T1, T2, T3, T4, T5);
impl_first_element_is!(T1, T2, T3, T4, T5, T6);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
impl_second_element_is!();
impl_second_element_is!(T1);
impl_second_element_is!(T1, T2);
impl_second_element_is!(T1, T2, T3);
impl_second_element_is!(T1, T2, T3, T4);
impl_second_element_is!(T1, T2, T3, T4, T5);
impl_second_element_is!(T1, T2, T3, T4, T5, T6);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);

View file

@ -10,9 +10,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **change:** axum-macro's MSRV is now 1.60 ([#1239])
- **added:** Support using a different rejection for `#[derive(FromRequest)]`
with `#[from_request(rejection(MyRejection))]` ([#1256])
- **breaking:** `#[derive(FromRequest)]` will no longer generate a rejection
enum but instead generate `type Rejection = axum::response::Response`. Use the
new `#[from_request(rejection(MyRejection))]` attribute to change this.
The `rejection_derive` attribute has also been removed ([#1272])
[#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1256]: https://github.com/tokio-rs/axum/pull/1256
[#1272]: https://github.com/tokio-rs/axum/pull/1272
# 0.2.3 (27. June, 2022)

View file

@ -1,14 +1,11 @@
use std::collections::HashSet;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use std::collections::HashSet;
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
let check_extractor_count = check_extractor_count(&item_fn);
let check_request_last_extractor = check_request_last_extractor(&item_fn);
let check_path_extractor = check_path_extractor(&item_fn);
let check_multiple_body_extractors = check_multiple_body_extractors(&item_fn);
let check_output_impls_into_response = check_output_impls_into_response(&item_fn);
// If the function is generic, we can't reliably check its inputs or whether the future it
@ -39,9 +36,7 @@ pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
quote! {
#item_fn
#check_extractor_count
#check_request_last_extractor
#check_path_extractor
#check_multiple_body_extractors
#check_output_impls_into_response
#check_inputs_and_future_send
}
@ -135,22 +130,6 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator<Item = (usize, &syn::FnAr
})
}
fn check_request_last_extractor(item_fn: &ItemFn) -> Option<TokenStream> {
let request_extractor_ident =
extractor_idents(item_fn).find(|(_, _, ident)| *ident == "Request");
if let Some((idx, fn_arg, _)) = request_extractor_ident {
if idx != item_fn.sig.inputs.len() - 1 {
return Some(
syn::Error::new_spanned(fn_arg, "`Request` extractor should always be last")
.to_compile_error(),
);
}
}
None
}
fn check_path_extractor(item_fn: &ItemFn) -> TokenStream {
let path_extractors = extractor_idents(item_fn)
.filter(|(_, _, ident)| *ident == "Path")
@ -174,30 +153,14 @@ fn check_path_extractor(item_fn: &ItemFn) -> TokenStream {
}
}
fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream {
let body_extractors = extractor_idents(item_fn)
.filter(|(_, _, ident)| {
*ident == "String"
|| *ident == "Bytes"
|| *ident == "Json"
|| *ident == "RawBody"
|| *ident == "BodyStream"
|| *ident == "Multipart"
|| *ident == "Request"
})
.collect::<Vec<_>>();
if body_extractors.len() > 1 {
body_extractors
.into_iter()
.map(|(_, arg, _)| {
syn::Error::new_spanned(arg, "Only one body extractor can be applied")
.to_compile_error()
})
.collect()
fn is_self_pat_type(typed: &syn::PatType) -> bool {
let ident = if let syn::Pat::Ident(ident) = &*typed.pat {
&ident.ident
} else {
quote! {}
}
return false;
};
ident == "self"
}
fn check_inputs_impls_from_request(
@ -205,6 +168,11 @@ fn check_inputs_impls_from_request(
body_ty: &Type,
state_ty: Type,
) -> TokenStream {
let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg {
FnArg::Receiver(_) => true,
FnArg::Typed(typed) => is_self_pat_type(typed),
});
item_fn
.sig
.inputs
@ -227,21 +195,53 @@ fn check_inputs_impls_from_request(
FnArg::Typed(typed) => {
let ty = &typed.ty;
let span = ty.span();
if is_self_pat_type(typed) {
(span, syn::parse_quote!(Self))
} else {
(span, ty.clone())
}
}
};
let name = format_ident!(
"__axum_macros_check_{}_{}_from_request",
let check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_check",
item_fn.sig.ident,
idx
idx,
span = span,
);
let call_check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_call_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn_body = if takes_self {
quote_spanned! {span=>
Self::#check_fn();
}
} else {
quote_spanned! {span=>
#check_fn();
}
};
quote_spanned! {span=>
#[allow(warnings)]
fn #name()
fn #check_fn<M>()
where
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send,
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send,
{}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
fn #call_check_fn()
{
#call_check_fn_body
}
}
})
.collect::<TokenStream>()
@ -380,11 +380,11 @@ fn check_future_send(item_fn: &ItemFn) -> TokenStream {
}
fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
let takes_self = item_fn
.sig
.inputs
.iter()
.any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg {
FnArg::Receiver(_) => true,
FnArg::Typed(typed) => is_self_pat_type(typed),
});
if takes_self {
return Some(quote! { Self:: });
}

View file

@ -1,10 +1,8 @@
use self::attr::{
parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr,
RejectionDeriveOptOuts,
};
use heck::ToUpperCamelCase;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use quote::{quote, quote_spanned};
use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token};
mod attr;
@ -18,7 +16,7 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
generics,
fields,
semi_token: _,
vis,
vis: _,
struct_token: _,
} = item;
@ -34,32 +32,15 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
generic_ident,
)
}
FromRequestContainerAttr::RejectionDerive(_, opt_outs) => {
error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field(ident, fields, vis, opt_outs, None)
}
FromRequestContainerAttr::Rejection(rejection) => {
error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field(
ident,
fields,
vis,
RejectionDeriveOptOuts::default(),
Some(rejection),
)
impl_struct_by_extracting_each_field(ident, fields, Some(rejection))
}
FromRequestContainerAttr::None => {
error_on_generic_ident(generic_ident)?;
impl_struct_by_extracting_each_field(
ident,
fields,
vis,
RejectionDeriveOptOuts::default(),
None,
)
impl_struct_by_extracting_each_field(ident, fields, None)
}
}
}
@ -88,12 +69,6 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
FromRequestContainerAttr::Via { path, rejection } => {
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection)
}
FromRequestContainerAttr::RejectionDerive(rejection_derive, _) => {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use `rejection_derive` on enums",
))
}
FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned(
rejection,
"cannot use `rejection` without `via`",
@ -197,22 +172,16 @@ fn error_on_generic_ident(generic_ident: Option<Ident>) -> syn::Result<()> {
fn impl_struct_by_extracting_each_field(
ident: syn::Ident,
fields: syn::Fields,
vis: syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
rejection: Option<syn::Path>,
) -> syn::Result<TokenStream> {
let extract_fields = extract_fields(&fields, &rejection)?;
let (rejection_ident, rejection) = if let Some(rejection) = rejection {
let rejection_ident = syn::parse_quote!(#rejection);
(rejection_ident, None)
let rejection_ident = if let Some(rejection) = rejection {
quote!(#rejection)
} else if has_no_fields(&fields) {
(syn::parse_quote!(::std::convert::Infallible), None)
quote!(::std::convert::Infallible)
} else {
let rejection_ident = rejection_ident(&ident);
let rejection =
extract_each_field_rejection(&ident, &fields, &vis, rejection_derive_opt_outs)?;
(rejection_ident, Some(rejection))
quote!(::axum::response::Response)
};
Ok(quote! {
@ -228,15 +197,14 @@ fn impl_struct_by_extracting_each_field(
type Rejection = #rejection_ident;
async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>,
mut req: axum::http::Request<B>,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self {
#(#extract_fields)*
})
}
}
#rejection
})
}
@ -248,11 +216,6 @@ fn has_no_fields(fields: &syn::Fields) -> bool {
}
}
fn rejection_ident(ident: &syn::Ident) -> syn::Type {
let ident = format_ident!("{}Rejection", ident);
syn::parse_quote!(#ident)
}
fn extract_fields(
fields: &syn::Fields,
rejection: &Option<syn::Path>,
@ -261,6 +224,8 @@ fn extract_fields(
.iter()
.enumerate()
.map(|(index, field)| {
let is_last = fields.len() - 1 == index;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = if let Some(ident) = &field.ident {
@ -286,40 +251,79 @@ fn extract_fields(
}
};
let rejection_variant_name = rejection_variant_name(field)?;
if peel_option(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req)
::axum::extract::FromRequest::from_request(req, state)
.await
.ok()
.map(#into_inner)
},
})
} else if peel_result_ok(&field.ty).is_some() {
} else {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req)
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.ok()
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
} else if peel_result_ok(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(#into_inner)
},
})
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
} else {
let map_err = if let Some(rejection) = rejection {
quote! { <#rejection as ::std::convert::From<_>>::from }
} else {
quote! { Self::Rejection::#rejection_variant_name }
quote! { ::axum::response::IntoResponse::into_response }
};
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req)
::axum::extract::FromRequest::from_request(req, state)
.await
.map(#into_inner)
.map_err(#map_err)?
},
})
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner)
.map_err(#map_err)?;
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
}
})
.collect()
@ -387,199 +391,6 @@ fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> {
}
}
fn extract_each_field_rejection(
ident: &syn::Ident,
fields: &syn::Fields,
vis: &syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
) -> syn::Result<TokenStream> {
let rejection_ident = rejection_ident(ident);
let variants = fields
.iter()
.map(|field| {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let field_ty = &field.ty;
let ty_span = field_ty.span();
let variant_name = rejection_variant_name(field)?;
let extractor_ty = if let Some((_, path)) = via {
if let Some(inner) = peel_option(field_ty) {
quote_spanned! {ty_span=>
::std::option::Option<#path<#inner>>
}
} else if let Some(inner) = peel_result_ok(field_ty) {
quote_spanned! {ty_span=>
::std::result::Result<#path<#inner>, TypedHeaderRejection>
}
} else {
quote_spanned! {ty_span=> #path<#field_ty> }
}
} else {
quote_spanned! {ty_span=> #field_ty }
};
Ok(quote_spanned! {ty_span=>
#[allow(non_camel_case_types)]
#variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection),
})
})
.collect::<syn::Result<Vec<_>>>()?;
let impl_into_response = {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => inner.into_response(),
})
})
.collect::<syn::Result<Vec<_>>>()?;
quote! {
#[automatically_derived]
impl ::axum::response::IntoResponse for #rejection_ident {
fn into_response(self) -> ::axum::response::Response {
match self {
#(#arms)*
}
}
}
}
};
let impl_display = if rejection_derive_opt_outs.derive_display() {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => inner.fmt(f),
})
})
.collect::<syn::Result<Vec<_>>>()?;
Some(quote! {
#[automatically_derived]
impl ::std::fmt::Display for #rejection_ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
match self {
#(#arms)*
}
}
}
})
} else {
None
};
let impl_error = if rejection_derive_opt_outs.derive_error() {
let arms = fields
.iter()
.map(|field| {
let variant_name = rejection_variant_name(field)?;
Ok(quote! {
Self::#variant_name(inner) => Some(inner),
})
})
.collect::<syn::Result<Vec<_>>>()?;
Some(quote! {
#[automatically_derived]
impl ::std::error::Error for #rejection_ident {
fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
match self {
#(#arms)*
}
}
}
})
} else {
None
};
let impl_debug = rejection_derive_opt_outs.derive_debug().then(|| {
quote! { #[derive(Debug)] }
});
Ok(quote! {
#impl_debug
#vis enum #rejection_ident {
#(#variants)*
}
#impl_into_response
#impl_display
#impl_error
})
}
fn rejection_variant_name(field: &syn::Field) -> syn::Result<syn::Ident> {
fn rejection_variant_name_for_type(out: &mut String, ty: &syn::Type) -> syn::Result<()> {
if let syn::Type::Path(type_path) = ty {
let segment = type_path
.path
.segments
.last()
.ok_or_else(|| syn::Error::new_spanned(ty, "Empty type path"))?;
out.push_str(&segment.ident.to_string());
match &segment.arguments {
syn::PathArguments::AngleBracketed(args) => {
let ty = if args.args.len() == 1 {
args.args.last().unwrap()
} else if args.args.len() == 2 {
if segment.ident == "Result" {
args.args.first().unwrap()
} else {
return Err(syn::Error::new_spanned(
segment,
"Only `Result<T, E>` is supported with two generics type paramters",
));
}
} else {
return Err(syn::Error::new_spanned(
&args.args,
"Expected exactly one or two type paramters",
));
};
if let syn::GenericArgument::Type(ty) = ty {
rejection_variant_name_for_type(out, ty)
} else {
Err(syn::Error::new_spanned(ty, "Expected type path"))
}
}
syn::PathArguments::Parenthesized(args) => {
Err(syn::Error::new_spanned(args, "Unsupported"))
}
syn::PathArguments::None => Ok(()),
}
} else {
Err(syn::Error::new_spanned(ty, "Expected type path"))
}
}
if let Some(ident) = &field.ident {
Ok(format_ident!("{}", ident.to_string().to_upper_camel_case()))
} else {
let mut out = String::new();
rejection_variant_name_for_type(&mut out, &field.ty)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
if let Some((_, path)) = via {
let via_ident = &path.segments.last().unwrap().ident;
Ok(format_ident!("{}{}", via_ident, out))
} else {
Ok(format_ident!("{}", out))
}
}
}
fn impl_struct_by_extracting_all_at_once(
ident: syn::Ident,
fields: syn::Fields,
@ -606,12 +417,16 @@ fn impl_struct_by_extracting_all_at_once(
let path_span = path.span();
let associated_rejection_type = if let Some(rejection) = &rejection {
quote! { #rejection }
let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
let rejection = quote! { #rejection };
let map_err = quote! { ::std::convert::From::from };
(rejection, map_err)
} else {
quote! {
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
}
let rejection = quote! {
::axum::response::Response
};
let map_err = quote! { ::axum::response::IntoResponse::into_response };
(rejection, map_err)
};
let rejection_bound = rejection.as_ref().map(|rejection| {
@ -658,18 +473,19 @@ fn impl_struct_by_extracting_all_at_once(
where
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound
B: ::std::marker::Send,
B: ::std::marker::Send + 'static,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>,
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<S, B>::from_request(req)
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(value)| #value_to_self)
.map_err(::std::convert::From::from)
.map_err(#map_err)
}
}
})
@ -707,12 +523,16 @@ fn impl_enum_by_extracting_all_at_once(
}
}
let associated_rejection_type = if let Some(rejection) = rejection {
quote! { #rejection }
let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
let rejection = quote! { #rejection };
let map_err = quote! { ::std::convert::From::from };
(rejection, map_err)
} else {
quote! {
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
}
let rejection = quote! {
::axum::response::Response
};
let map_err = quote! { ::axum::response::IntoResponse::into_response };
(rejection, map_err)
};
let path_span = path.span();
@ -730,12 +550,13 @@ fn impl_enum_by_extracting_all_at_once(
type Rejection = #associated_rejection_type;
async fn from_request(
req: &mut ::axum::extract::RequestParts<S, B>,
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<S, B>::from_request(req)
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(inner)| inner)
.map_err(::std::convert::From::from)
.map_err(#map_err)
}
}
})

View file

@ -16,13 +16,11 @@ pub(crate) enum FromRequestContainerAttr {
rejection: Option<syn::Path>,
},
Rejection(syn::Path),
RejectionDerive(kw::rejection_derive, RejectionDeriveOptOuts),
None,
}
pub(crate) mod kw {
syn::custom_keyword!(via);
syn::custom_keyword!(rejection_derive);
syn::custom_keyword!(rejection);
syn::custom_keyword!(Display);
syn::custom_keyword!(Debug);
@ -55,7 +53,6 @@ pub(crate) fn parse_container_attrs(
let attrs = parse_attrs::<ContainerAttr>(attrs)?;
let mut out_via = None;
let mut out_rejection_derive = None;
let mut out_rejection = None;
// we track the index of the attribute to know which comes last
@ -69,16 +66,6 @@ pub(crate) fn parse_container_attrs(
out_via = Some((idx, via, path));
}
}
ContainerAttr::RejectionDerive {
rejection_derive,
opt_outs,
} => {
if out_rejection_derive.is_some() {
return Err(double_attr_error("rejection_derive", rejection_derive));
} else {
out_rejection_derive = Some((idx, rejection_derive, opt_outs));
}
}
ContainerAttr::Rejection { rejection, path } => {
if out_rejection.is_some() {
return Err(double_attr_error("rejection", rejection));
@ -89,55 +76,20 @@ pub(crate) fn parse_container_attrs(
}
}
match (out_via, out_rejection_derive, out_rejection) {
(Some((via_idx, via, _)), Some((rejection_derive_idx, rejection_derive, _)), _) => {
if via_idx > rejection_derive_idx {
Err(syn::Error::new_spanned(
via,
"cannot use both `rejection_derive` and `via`",
))
} else {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use both `via` and `rejection_derive`",
))
}
}
(
_,
Some((rejection_derive_idx, rejection_derive, _)),
Some((rejection_idx, rejection, _)),
) => {
if rejection_idx > rejection_derive_idx {
Err(syn::Error::new_spanned(
rejection,
"cannot use both `rejection_derive` and `rejection`",
))
} else {
Err(syn::Error::new_spanned(
rejection_derive,
"cannot use both `rejection` and `rejection_derive`",
))
}
}
(Some((_, _, path)), None, None) => Ok(FromRequestContainerAttr::Via {
match (out_via, out_rejection) {
(Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via {
path,
rejection: None,
}),
(Some((_, _, path)), None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via {
(Some((_, _, path)), Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via {
path,
rejection: Some(rejection),
}),
(None, Some((_, rejection_derive, opt_outs)), _) => Ok(
FromRequestContainerAttr::RejectionDerive(rejection_derive, opt_outs),
),
(None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)),
(None, None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)),
(None, None, None) => Ok(FromRequestContainerAttr::None),
(None, None) => Ok(FromRequestContainerAttr::None),
}
}
@ -172,10 +124,6 @@ enum ContainerAttr {
rejection: kw::rejection,
path: syn::Path,
},
RejectionDerive {
rejection_derive: kw::rejection_derive,
opt_outs: RejectionDeriveOptOuts,
},
}
impl Parse for ContainerAttr {
@ -186,14 +134,6 @@ impl Parse for ContainerAttr {
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else if lh.peek(kw::rejection_derive) {
let rejection_derive = input.parse::<kw::rejection_derive>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|opt_outs| Self::RejectionDerive {
rejection_derive,
opt_outs,
})
} else if lh.peek(kw::rejection) {
let rejection = input.parse::<kw::rejection>()?;
let content;
@ -224,82 +164,3 @@ impl Parse for FieldAttr {
}
}
}
#[derive(Default)]
pub(crate) struct RejectionDeriveOptOuts {
debug: Option<kw::Debug>,
display: Option<kw::Display>,
error: Option<kw::Error>,
}
impl RejectionDeriveOptOuts {
pub(crate) fn derive_debug(&self) -> bool {
self.debug.is_none()
}
pub(crate) fn derive_display(&self) -> bool {
self.display.is_none()
}
pub(crate) fn derive_error(&self) -> bool {
self.error.is_none()
}
}
impl Parse for RejectionDeriveOptOuts {
fn parse(input: ParseStream) -> syn::Result<Self> {
fn parse_opt_out<T>(out: &mut Option<T>, ident: &str, input: ParseStream) -> syn::Result<()>
where
T: Parse,
{
if out.is_some() {
Err(input.error(format!("`{}` opt out specified more than once", ident)))
} else {
*out = Some(input.parse()?);
Ok(())
}
}
let mut debug = None::<kw::Debug>;
let mut display = None::<kw::Display>;
let mut error = None::<kw::Error>;
while !input.is_empty() {
input.parse::<Token![!]>()?;
let lh = input.lookahead1();
if lh.peek(kw::Debug) {
parse_opt_out(&mut debug, "Debug", input)?;
} else if lh.peek(kw::Display) {
parse_opt_out(&mut display, "Display", input)?;
} else if lh.peek(kw::Error) {
parse_opt_out(&mut error, "Error", input)?;
} else {
return Err(lh.error());
}
input.parse::<Token![,]>().ok();
}
if error.is_none() {
match (debug, display) {
(Some(debug), Some(_)) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`"));
}
(Some(debug), None) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`"));
}
(None, Some(display)) => {
return Err(syn::Error::new_spanned(display, "opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`"));
}
(None, None) => {}
}
}
Ok(Self {
debug,
display,
error,
})
}
}

View file

@ -86,6 +86,20 @@ mod typed_path;
///
/// This requires that each field is an extractor (i.e. implements [`FromRequest`]).
///
/// ```compile_fail
/// use axum_macros::FromRequest;
/// use axum::body::Bytes;
///
/// #[derive(FromRequest)]
/// struct MyExtractor {
/// // only the last field can implement `FromRequest`
/// // other fields must only implement `FromRequestParts`
/// bytes: Bytes,
/// string: String,
/// }
/// ```
/// Note that only the last field can consume the request body. Therefore this doesn't compile:
///
/// ## Extracting via another extractor
///
/// You can use `#[from_request(via(...))]` to extract a field via another extractor, meaning the
@ -157,95 +171,15 @@ mod typed_path;
///
/// ## The rejection
///
/// A rejection enum is also generated. It has a variant for each field:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{Extension, TypedHeader},
/// headers::ContentType,
/// body::Bytes,
/// };
///
/// #[derive(FromRequest)]
/// struct MyExtractor {
/// #[from_request(via(Extension))]
/// state: State,
/// #[from_request(via(TypedHeader))]
/// content_type: ContentType,
/// request_body: Bytes,
/// }
///
/// // also generates
/// //
/// // #[derive(Debug)]
/// // enum MyExtractorRejection {
/// // State(ExtensionRejection),
/// // ContentType(TypedHeaderRejection),
/// // RequestBody(BytesRejection),
/// // }
/// //
/// // impl axum::response::IntoResponse for MyExtractor { ... }
/// //
/// // impl std::fmt::Display for MyExtractor { ... }
/// //
/// // impl std::error::Error for MyExtractor { ... }
///
/// #[derive(Clone)]
/// struct State {
/// // ...
/// }
/// ```
///
/// The rejection's `std::error::Error::source` implementation returns the inner rejection. This
/// can be used to access source errors for example to customize rejection responses. Note this
/// means the inner rejection types must themselves implement `std::error::Error`. All extractors
/// in axum does this.
///
/// You can opt out of this using `#[from_request(rejection_derive(...))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{FromRequest, RequestParts},
/// http::StatusCode,
/// headers::ContentType,
/// body::Bytes,
/// async_trait,
/// };
///
/// #[derive(FromRequest)]
/// #[from_request(rejection_derive(!Display, !Error))]
/// struct MyExtractor {
/// other: OtherExtractor,
/// }
///
/// struct OtherExtractor;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for OtherExtractor
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
/// }
/// ```
///
/// You can also use your own rejection type with `#[from_request(rejection(YourType))]`:
/// By default [`axum::response::Response`] will be used as the rejection. You can also use your own
/// rejection type with `#[from_request(rejection(YourType))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{
/// rejection::{ExtensionRejection, StringRejection},
/// FromRequest, RequestParts,
/// FromRequest,
/// },
/// Extension,
/// response::{Response, IntoResponse},
@ -414,6 +348,7 @@ mod typed_path;
/// ```
///
/// [`FromRequest`]: https://docs.rs/axum/latest/axum/extract/trait.FromRequest.html
/// [`axum::response::Response`]: https://docs.rs/axum/0.6/axum/response/type.Response.html
/// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html
#[proc_macro_derive(FromRequest, attributes(from_request))]
pub fn derive_from_request(item: TokenStream) -> TokenStream {

View file

@ -127,15 +127,17 @@ fn expand_named_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
B: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req)
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await
.map(|path| path.0)
#map_err_rejection
@ -230,15 +232,17 @@ fn expand_unnamed_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
B: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req)
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request_parts(parts, state)
.await
.map(|path| path.0)
#map_err_rejection
@ -312,15 +316,17 @@ fn expand_unit_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
B: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
_state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
if parts.uri.path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self)
} else {
#create_rejection
@ -390,7 +396,7 @@ enum Segment {
fn path_rejection() -> TokenStream {
quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
<::axum::extract::Path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection
}
}

View file

@ -1,17 +1,22 @@
error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied
error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23
|
4 | async fn handler(foo: bool) {}
| ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool`
| ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool`
|
= help: the following other types implement trait `FromRequest<S, B>`:
<() as FromRequest<S, B>>
<(T1, T2) as FromRequest<S, B>>
<(T1, T2, T3) as FromRequest<S, B>>
<(T1, T2, T3, T4) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest<S, B>>
and 34 others
= help: see issue #48214
= help: the following other types implement trait `FromRequestParts<S>`:
<() as FromRequestParts<S>>
<(T1, T2) as FromRequestParts<S>>
<(T1, T2, T3) as FromRequestParts<S>>
<(T1, T2, T3, T4) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 26 others
= note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool`
note: required by a bound in `__axum_macros_check_handler_0_from_request_check`
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23
|
4 | async fn handler(foo: bool) {}
| ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check`

View file

@ -1,6 +1,7 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequest,
http::Request,
};
use axum_macros::debug_handler;
@ -9,12 +10,12 @@ struct A;
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_mut.rs:24:22
--> tests/debug_handler/fail/extract_self_mut.rs:25:22
|
24 | async fn handler(&mut self) {}
25 | async fn handler(&mut self) {}
| ^^^^^^^^^

View file

@ -1,6 +1,7 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequest,
http::Request,
};
use axum_macros::debug_handler;
@ -9,12 +10,12 @@ struct A;
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_ref.rs:24:22
--> tests/debug_handler/fail/extract_self_ref.rs:25:22
|
24 | async fn handler(&self) {}
25 | async fn handler(&self) {}
| ^^^^^

View file

@ -1,7 +0,0 @@
use axum_macros::debug_handler;
use axum::body::Bytes;
#[debug_handler]
async fn handler(_: String, _: Bytes) {}
fn main() {}

View file

@ -1,11 +0,0 @@
error: Only one body extractor can be applied
--> tests/debug_handler/fail/multiple_body_extractors.rs:5:18
|
5 | async fn handler(_: String, _: Bytes) {}
| ^^^^^^^^^
error: Only one body extractor can be applied
--> tests/debug_handler/fail/multiple_body_extractors.rs:5:29
|
5 | async fn handler(_: String, _: Bytes) {}
| ^^^^^^^^

View file

@ -1,7 +0,0 @@
use axum::{body::Body, extract::Extension, http::Request};
use axum_macros::debug_handler;
#[debug_handler]
async fn handler(_: Request<Body>, _: Extension<String>) {}
fn main() {}

View file

@ -1,5 +0,0 @@
error: `Request` extractor should always be last
--> tests/debug_handler/fail/request_not_last.rs:5:18
|
5 | async fn handler(_: Request<Body>, _: Extension<String>) {}
| ^^^^^^^^^^^^^^^^

View file

@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied
(Response<()>, T1, T2, R)
(Response<()>, T1, T2, T3, R)
(Response<()>, T1, T2, T3, T4, R)
and 123 others
and 122 others
note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check`
--> tests/debug_handler/fail/wrong_return_type.rs:4:23
|

View file

@ -1,8 +1,4 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
response::IntoResponse,
};
use axum::{async_trait, extract::FromRequest, http::Request, response::IntoResponse};
use axum_macros::debug_handler;
fn main() {}
@ -122,12 +118,12 @@ impl A {
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,6 +1,7 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequest,
http::Request,
};
use axum_macros::debug_handler;
@ -9,12 +10,25 @@ struct A;
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Box<A>
where
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ();
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
@ -22,6 +36,9 @@ where
impl A {
#[debug_handler]
async fn handler(self) {}
#[debug_handler]
async fn handler_with_qualified_self(self: Box<Self>) {}
}
fn main() {}

View file

@ -1,6 +1,7 @@
use axum_macros::debug_handler;
use axum::extract::{FromRef, FromRequest, RequestParts};
use axum::extract::{FromRef, FromRequest};
use axum::async_trait;
use axum::http::Request;
#[debug_handler(state = AppState)]
async fn handler(_: A) {}
@ -13,13 +14,13 @@ struct A;
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
B: Send + 'static,
S: Send + Sync,
AppState: FromRef<S>,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug, !Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug, !Display))]
| ^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug))]
| ^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Display))]
| ^^^^^^^

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Error, !Error))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: `Error` opt out specified more than once
--> tests/from_request/fail/derive_opt_out_duplicate.rs:4:42
|
4 | #[from_request(rejection_derive(!Error, !Error))]
| ^^^^^

View file

@ -1,7 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error))]
enum Extractor {}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use `rejection_derive` on enums
--> tests/from_request/fail/enum_rejection_derive.rs:4:16
|
4 | #[from_request(rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^

View file

@ -1,12 +0,0 @@
use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error))]
struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {}
fn main() {
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -1,21 +0,0 @@
error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)]
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:6:18
|
6 | struct Extractor<T>(T);
| ^
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46
|
11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
| top_level_handler_fn!(get, GET);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get`
= note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error), via(axum::Extension))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use both `rejection_derive` and `via`
--> tests/from_request/fail/rejection_derive_and_via.rs:4:42
|
4 | #[from_request(rejection_derive(!Error), via(axum::Extension))]
| ^^^

View file

@ -1,4 +1,4 @@
error: expected one of: `via`, `rejection_derive`, `rejection`
error: expected `via` or `rejection`
--> tests/from_request/fail/unknown_attr_container.rs:4:16
|
4 | #[from_request(foo)]

View file

@ -1,9 +0,0 @@
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
#[from_request(via(axum::Extension), rejection_derive(!Error))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -1,5 +0,0 @@
error: cannot use both `via` and `rejection_derive`
--> tests/from_request/fail/via_and_rejection_derive.rs:4:38
|
4 | #[from_request(via(axum::Extension), rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^

View file

@ -1,6 +1,7 @@
use axum::{
body::Body,
extract::{rejection::JsonRejection, FromRequest, Json},
extract::{FromRequest, Json},
response::Response,
};
use axum_macros::FromRequest;
use serde::Deserialize;
@ -15,7 +16,7 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<(), Body, Rejection = JsonRejection>,
Extractor: FromRequest<(), Body, Rejection = Response>,
{
}

View file

@ -1,38 +0,0 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
response::{IntoResponse, Response},
};
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display, !Error))]
struct Extractor {
other: OtherExtractor,
}
struct OtherExtractor;
#[async_trait]
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send,
S: Send + Sync,
{
type Rejection = OtherExtractorRejection;
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
#[derive(Debug)]
struct OtherExtractorRejection;
impl IntoResponse for OtherExtractorRejection {
fn into_response(self) -> Response {
unimplemented!()
}
}
fn main() {}

View file

@ -1,10 +1,10 @@
use axum::{
body::Body,
extract::{FromRequest, TypedHeader, rejection::{TypedHeaderRejection, StringRejection}},
extract::{FromRequest, TypedHeader, rejection::TypedHeaderRejection},
response::Response,
headers::{self, UserAgent},
};
use axum_macros::FromRequest;
use std::convert::Infallible;
#[derive(FromRequest)]
struct Extractor {
@ -18,34 +18,8 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
Extractor: FromRequest<(), Body, Rejection = Response>,
{
}
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::Uri(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Body(inner) => {
let _: StringRejection = inner;
}
ExtractorRejection::UserAgent(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::ContentType(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::Etag(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Host(inner) => {
let _: Infallible = inner;
}
}
}
fn main() {}

View file

@ -1,13 +1,13 @@
use axum::{
body::Body,
response::Response,
extract::{
rejection::{ExtensionRejection, TypedHeaderRejection},
rejection::TypedHeaderRejection,
Extension, FromRequest, TypedHeader,
},
headers::{self, UserAgent},
};
use axum_macros::FromRequest;
use std::convert::Infallible;
#[derive(FromRequest)]
struct Extractor {
@ -25,33 +25,10 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
Extractor: FromRequest<(), Body, Rejection = Response>,
{
}
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::State(inner) => {
let _: ExtensionRejection = inner;
}
ExtractorRejection::UserAgent(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::ContentType(inner) => {
let _: TypedHeaderRejection = inner;
}
ExtractorRejection::Etag(inner) => {
let _: Infallible = inner;
}
ExtractorRejection::Host(inner) => {
let _: Infallible = inner;
}
}
}
#[derive(Clone)]
struct State;

View file

@ -1,7 +1,7 @@
use axum::{
async_trait,
extract::{rejection::ExtensionRejection, FromRequest, RequestParts},
http::StatusCode,
extract::{rejection::ExtensionRejection, FromRequest},
http::{StatusCode, Request},
response::{IntoResponse, Response},
routing::get,
Extension, Router,
@ -36,7 +36,7 @@ where
// this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String);
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
todo!()
}
}

View file

@ -1,4 +1,4 @@
use axum::extract::{Query, rejection::*};
use axum::extract::Query;
use axum_macros::FromRequest;
use serde::Deserialize;
@ -17,18 +17,4 @@ where
{
}
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::QueryPayload(inner) => {
let _: QueryRejection = inner;
}
ExtractorRejection::JsonPayload(inner) => {
let _: JsonRejection = inner;
}
}
}
fn main() {}

View file

@ -1,4 +1,5 @@
use axum::extract::{Query, rejection::*};
use axum::extract::Query;
use axum::response::Response;
use axum_macros::FromRequest;
use serde::Deserialize;
@ -8,26 +9,12 @@ struct Extractor(
#[from_request(via(axum::extract::Json))] Payload,
);
fn assert_rejection(rejection: ExtractorRejection)
where
ExtractorRejection: std::fmt::Debug + std::fmt::Display + std::error::Error,
{
match rejection {
ExtractorRejection::QueryPayload(inner) => {
let _: QueryRejection = inner;
}
ExtractorRejection::JsonPayload(inner) => {
let _: JsonRejection = inner;
}
}
}
#[derive(Deserialize)]
struct Payload {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<(), axum::body::Body>,
Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = Response>,
{
}

View file

@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is
(T0, T1, T2, T3)
and 138 others
= note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath`
= note: required because of the requirements on the impl of `FromRequest<S, B>` for `axum::extract::Path<MyPath>`
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `axum::extract::Path<MyPath>`
= note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -16,7 +16,6 @@ async fn result_handler(_: Result<UsersShow, PathRejection>) {}
#[typed_path("/users")]
struct UsersIndex;
#[axum_macros::debug_handler]
async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
fn main() {

View file

@ -237,18 +237,48 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
}
}
```
- **breaking:** The following types or traits have a new `S` type param
(`()` by default) which represents the state ([#1155]):
- `FromRequest`
- `RequestParts`
- `Router`
- `MethodRouter`
- `Handler`
- **breaking:** It is now only possible for one extractor per handler to consume
the request body. In 0.5 doing so would result in runtime errors but in 0.6 it
is a compile error ([#1272])
axum enforces this by only allowing the _last_ extractor to consume the
request.
For example:
```rust
use axum::{Json, http::HeaderMap};
// This wont compile on 0.6 because both `Json` and `String` need to consume
// the request body. You can use either `Json` or `String`, but not both.
async fn handler_1(
json: Json<serde_json::Value>,
string: String,
) {}
// This won't work either since `Json` is not the last extractor.
async fn handler_2(
json: Json<serde_json::Value>,
headers: HeaderMap,
) {}
// This works!
async fn handler_3(
headers: HeaderMap,
json: Json<serde_json::Value>,
) {}
```
This is done by reworking the `FromRequest` trait and introducing a new
`FromRequestParts` trait.
If your extractor needs to consume the request body then you should implement
`FromRequest`, otherwise implement `FromRequestParts`.
This extractor in 0.5:
```rust
struct MyExtractor;
struct MyExtractor { /* ... */ }
#[async_trait]
impl<B> FromRequest<B> for MyExtractor
@ -266,22 +296,53 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Becomes this in 0.6:
```rust
struct MyExtractor;
use axum::{
extract::{FromRequest, FromRequestParts},
http::{StatusCode, Request, request::Parts},
async_trait,
};
struct MyExtractor { /* ... */ }
// implement `FromRequestParts` if you don't need to consume the request body
#[async_trait]
impl<S> FromRequestParts<S> for MyExtractor
where
S: Send + Sync,
B: Send + 'static,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// ...
}
}
// implement `FromRequest` if you do need to consume the request body
#[async_trait]
impl<S, B> FromRequest<S, B> for MyExtractor
where
S: Send + Sync,
B: Send,
B: Send + 'static,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
// ...
}
}
```
- **breaking:** `RequestParts` has been removed as part of the `FromRequest`
rework ([#1272])
- **breaking:** `BodyAlreadyExtracted` has been removed ([#1272])
- **breaking:** The following types or traits have a new `S` type param
which represents the state ([#1155]):
- `Router`, defaults to `()`
- `MethodRouter`, defaults to `()`
- `FromRequest`, no default
- `Handler`, no default
## Middleware
- **breaking:** Remove `extractor_middleware` which was previously deprecated.
@ -310,6 +371,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1248]: https://github.com/tokio-rs/axum/pull/1248
[#1272]: https://github.com/tokio-rs/axum/pull/1272
[#924]: https://github.com/tokio-rs/axum/pull/924
# 0.5.15 (9. August, 2022)

View file

@ -5,14 +5,15 @@ Types and traits for extracting data from requests.
- [Intro](#intro)
- [Common extractors](#common-extractors)
- [Applying multiple extractors](#applying-multiple-extractors)
- [Be careful when extracting `Request`](#be-careful-when-extracting-request)
- [The order of extractors](#the-order-of-extractors)
- [Optional extractors](#optional-extractors)
- [Customizing extractor responses](#customizing-extractor-responses)
- [Accessing inner errors](#accessing-inner-errors)
- [Defining custom extractors](#defining-custom-extractors)
- [Accessing other extractors in `FromRequest` implementations](#accessing-other-extractors-in-fromrequest-implementations)
- [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations)
- [Request body extractors](#request-body-extractors)
- [Running extractors from middleware](#running-extractors-from-middleware)
- [Wrapping extractors](#wrapping-extractors)
# Intro
@ -152,83 +153,74 @@ async fn get_user_things(
# };
```
# The order of extractors
Extractors always run in the order of the function parameters that is from
left to right.
# Be careful when extracting `Request`
The request body is an asynchronous stream that can only be consumed once.
Therefore you can only have one extractor that consumes the request body. axum
enforces by that requiring such extractors to be the _last_ argument your
handler takes.
[`Request`] is itself an extractor:
For example
```rust,no_run
use axum::{http::Request, body::Body};
```rust
use axum::http::{Method, HeaderMap};
async fn handler(request: Request<Body>) {
async fn handler(
// `Method` and `HeaderMap` don't consume the request body so they can
// put anywhere in the argument list
method: Method,
headers: HeaderMap,
// `String` consumes the request body and thus must be the last extractor
body: String,
) {
// ...
}
#
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
```
However be careful when combining it with other extractors since it will consume
all extensions and the request body. Therefore it is recommended to always apply
the request extractor last:
We get a compile error if `String` isn't the last extractor:
```rust,no_run
use axum::{http::Request, Extension, body::Body};
```rust,compile_fail
use axum::http::Method;
// this will fail at runtime since `Request<Body>` will have consumed all the
// extensions so `Extension<State>` will be missing
async fn broken(
request: Request<Body>,
Extension(state): Extension<State>,
async fn handler(
// this doesn't work since `String` must be the last argument
body: String,
method: Method,
) {
// ...
}
// this will work since we extract `Extension<State>` before `Request<Body>`
async fn works(
Extension(state): Extension<State>,
request: Request<Body>,
) {
// ...
}
#[derive(Clone)]
struct State {};
#
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
```
# Extracting request bodies
This also means you cannot consume the request body twice:
Since request bodies are asynchronous streams they can only be extracted once:
```rust,compile_fail
use axum::Json;
use serde::Deserialize;
```rust,no_run
use axum::{Json, http::Request, body::{Bytes, Body}};
use serde_json::Value;
#[derive(Deserialize)]
struct Payload {}
// this will fail at runtime since `Json<Value>` and `Bytes` both attempt to extract
// the body
//
// the solution is to only extract the body once so remove either
// `body_json: Json<Value>` or `body_bytes: Bytes`
async fn broken(
body_json: Json<Value>,
body_bytes: Bytes,
) {
// ...
}
// this doesn't work either for the same reason: `Bytes` and `Request<Body>`
// both extract the body
async fn also_broken(
body_json: Json<Value>,
request: Request<Body>,
async fn handler(
// `String` and `Json` both consume the request body
// so they cannot both be used
string_body: String,
json_body: Json<Payload>,
) {
// ...
}
#
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
```
Also keep this in mind if you extract or otherwise consume the body in
middleware. You either need to not extract the body in handlers or make sure
your middleware reinserts the body using [`RequestParts::body_mut`] so it's
available to handlers.
axum enforces this by requiring the last extractor implements [`FromRequest`]
and all others implement [`FromRequestParts`].
# Optional extractors
@ -407,29 +399,38 @@ happen without major breaking versions.
# Defining custom extractors
You can also define your own extractors by implementing [`FromRequest`]:
You can also define your own extractors by implementing either
[`FromRequestParts`] or [`FromRequest`].
## Implementing `FromRequestParts`
Implement `FromRequestParts` if your extractor doesn't need access to the
request body:
```rust,no_run
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::FromRequestParts,
routing::get,
Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
request::Parts,
},
};
use http::{StatusCode, header::{HeaderValue, USER_AGENT}};
struct ExtractUserAgent(HeaderValue);
#[async_trait]
impl<S, B> FromRequest<S, B> for ExtractUserAgent
impl<S> FromRequestParts<S> for ExtractUserAgent
where
B: Send,
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = req.headers().get(USER_AGENT) {
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = parts.headers.get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone()))
} else {
Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
@ -447,7 +448,58 @@ let app = Router::new().route("/foo", get(handler));
# };
```
# Accessing other extractors in [`FromRequest`] implementations
## Implementing `FromRequest`
If your extractor needs to consume the request body you must implement [`FromRequest`]
```rust,no_run
use axum::{
async_trait,
extract::FromRequest,
response::{Response, IntoResponse},
body::Bytes,
routing::get,
Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
Request,
},
};
struct ValidatedBody(Bytes);
#[async_trait]
impl<S, B> FromRequest<S, B> for ValidatedBody
where
Bytes: FromRequest<S, B>,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let body = Bytes::from_request(req, state)
.await
.map_err(IntoResponse::into_response)?;
// do validation...
Ok(Self(body))
}
}
async fn handler(ValidatedBody(body): ValidatedBody) {
// ...
}
let app = Router::new().route("/foo", get(handler));
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
# Accessing other extractors in `FromRequest` or `FromRequestParts` implementations
When defining custom extractors you often need to access another extractors
in your implementation.
@ -455,9 +507,9 @@ in your implementation.
```rust
use axum::{
async_trait,
extract::{Extension, FromRequest, RequestParts, TypedHeader},
extract::{Extension, FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
http::StatusCode,
http::{StatusCode, request::Parts},
response::{IntoResponse, Response},
routing::get,
Router,
@ -473,20 +525,19 @@ struct AuthenticatedUser {
}
#[async_trait]
impl<S, B> FromRequest<S, B> for AuthenticatedUser
impl<S> FromRequestParts<S> for AuthenticatedUser
where
B: Send,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(token)) =
TypedHeader::<Authorization<Bearer>>::from_request(req)
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
let Extension(state): Extension<State> = Extension::from_request(req)
let Extension(state): Extension<State> = Extension::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
@ -584,14 +635,13 @@ let app = Router::new()
# Running extractors from middleware
Extractors can also be run from middleware by making a [`RequestParts`] and
running your extractor:
Extractors can also be run from middleware:
```rust
use axum::{
Router,
middleware::{self, Next},
extract::{RequestParts, TypedHeader},
extract::{TypedHeader, FromRequestParts},
http::{Request, StatusCode},
response::Response,
headers::authorization::{Authorization, Bearer},
@ -604,12 +654,11 @@ async fn auth_middleware<B>(
where
B: Send,
{
// running extractors requires a `RequestParts`
let mut request_parts = RequestParts::new(request);
// running extractors requires a `axum::http::request::Parts`
let (mut parts, body) = request.into_parts();
// `TypedHeader<Authorization<Bearer>>` extracts the auth token but
// `RequestParts::extract` works with anything that implements `FromRequest`
let auth = request_parts.extract::<TypedHeader<Authorization<Bearer>>>()
// `TypedHeader<Authorization<Bearer>>` extracts the auth token
let auth = TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, &())
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;
@ -617,14 +666,8 @@ where
return Err(StatusCode::UNAUTHORIZED);
}
// get the request back so we can run `next`
//
// `try_into_request` will fail if you have extracted the request body. We
// know that `TypedHeader` never does that.
//
// see the `consume-body-in-extractor-or-middleware` example if you need to
// extract the body
let request = request_parts.try_into_request().expect("body extracted");
// reconstruct the request
let request = Request::from_parts(parts, body);
Ok(next.run(request).await)
}
@ -638,8 +681,81 @@ let app = Router::new().layer(middleware::from_fn(auth_middleware));
# let _: Router<()> = app;
```
# Wrapping extractors
If you want write an extractor that generically wraps another extractor (that
may or may not consume the request body) you should implement both
[`FromRequest`] and [`FromRequestParts`]:
```rust
use axum::{
Router,
routing::get,
extract::{FromRequest, FromRequestParts},
http::{Request, HeaderMap, request::Parts},
async_trait,
};
use std::time::{Instant, Duration};
// an extractor that wraps another and measures how long time it takes to run
struct Timing<E> {
extractor: E,
duration: Duration,
}
// we must implement both `FromRequestParts`
#[async_trait]
impl<S, T> FromRequestParts<S> for Timing<T>
where
S: Send + Sync,
T: FromRequestParts<S>,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request_parts(parts, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
// and `FromRequest`
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Timing<T>
where
B: Send + 'static,
S: Send + Sync,
T: FromRequest<S, B>,
{
type Rejection = T::Rejection;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request(req, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
async fn handler(
// this uses the `FromRequestParts` impl
_: Timing<HeaderMap>,
// this uses the `FromRequest` impl
_: Timing<String>,
) {}
# let _: axum::routing::MethodRouter = axum::routing::get(handler);
```
[`body::Body`]: crate::body::Body
[customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs
[`HeaderMap`]: https://docs.rs/http/latest/http/header/struct.HeaderMap.html
[`Request`]: https://docs.rs/http/latest/http/struct.Request.html
[`RequestParts::body_mut`]: crate::extract::RequestParts::body_mut
[`JsonRejection::JsonDataError`]: rejection::JsonRejection::JsonDataError

View file

@ -1,8 +1,8 @@
#![doc = include_str!("../docs/error_handling.md")]
use crate::{
extract::{FromRequest, RequestParts},
http::{Request, StatusCode},
extract::FromRequestParts,
http::Request,
response::{IntoResponse, Response},
};
use std::{
@ -161,7 +161,7 @@ macro_rules! impl_service {
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
$( $ty: FromRequest<(), B> + Send,)*
$( $ty: FromRequestParts<()> + Send,)*
B: Send + 'static,
{
type Response = Response;
@ -181,21 +181,16 @@ macro_rules! impl_service {
let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move {
let mut req = RequestParts::new(req);
let (mut parts, body) = req.into_parts();
$(
let $ty = match $ty::from_request(&mut req).await {
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response()),
};
)*
let req = match req.try_into_request() {
Ok(req) => req,
Err(err) => {
return Ok((StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response());
}
};
let req = Request::from_parts(parts, body);
match inner.oneshot(req).await {
Ok(res) => Ok(res.into_response()),

View file

@ -1,10 +1,10 @@
use crate::{
extract::{rejection::*, FromRequest, RequestParts},
response::IntoResponseParts,
};
use crate::{extract::rejection::*, response::IntoResponseParts};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response, ResponseParts};
use http::Request;
use axum_core::{
extract::FromRequestParts,
response::{IntoResponse, Response, ResponseParts},
};
use http::{request::Parts, Request};
use std::{
convert::Infallible,
ops::Deref,
@ -73,17 +73,16 @@ use tower_service::Service;
pub struct Extension<T>(pub T);
#[async_trait]
impl<T, S, B> FromRequest<S, B> for Extension<T>
impl<T, S> FromRequestParts<S> for Extension<T>
where
T: Clone + Send + Sync + 'static,
B: Send,
S: Send + Sync,
{
type Rejection = ExtensionRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let value = req
.extensions()
.extensions
.get::<T>()
.ok_or_else(|| {
MissingExtension::from_err(format!(

View file

@ -4,9 +4,10 @@
//!
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
use super::{Extension, FromRequest, RequestParts};
use super::{Extension, FromRequestParts};
use crate::middleware::AddExtension;
use async_trait::async_trait;
use http::request::Parts;
use hyper::server::conn::AddrStream;
use std::{
convert::Infallible,
@ -128,16 +129,15 @@ opaque_future! {
pub struct ConnectInfo<T>(pub T);
#[async_trait]
impl<S, B, T> FromRequest<S, B> for ConnectInfo<T>
impl<S, T> FromRequestParts<S> for ConnectInfo<T>
where
B: Send,
S: Send + Sync,
T: Clone + Send + Sync + 'static,
{
type Rejection = <Extension<Self> as FromRequest<S, B>>::Rejection;
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request(req).await?;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request_parts(parts, state).await?;
Ok(connect_info)
}
}

View file

@ -1,7 +1,7 @@
use super::{rejection::*, FromRequest, RequestParts};
use super::{rejection::*, FromRequest};
use async_trait::async_trait;
use axum_core::response::IntoResponse;
use http::Method;
use axum_core::{extract::FromRequestParts, response::IntoResponse};
use http::{request::Parts, Method, Request};
use std::ops::Deref;
/// Extractor that will reject requests with a body larger than some size.
@ -40,25 +40,58 @@ impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where
T: FromRequest<S, B>,
T::Rejection: IntoResponse,
B: Send,
B: Send + 'static,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
validate::<_, N>(&parts)?;
let req = Request::from_parts(parts, body);
let value = T::from_request(req, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
}
#[async_trait]
impl<T, S, const N: u64> FromRequestParts<S> for ContentLengthLimit<T, N>
where
T: FromRequestParts<S>,
T::Rejection: IntoResponse,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
validate::<_, N>(parts)?;
let value = T::from_request_parts(parts, state)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
}
fn validate<E, const N: u64>(parts: &Parts) -> Result<(), ContentLengthLimitRejection<E>> {
let content_length = parts
.headers
.get(http::header::CONTENT_LENGTH)
.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
match (content_length, req.method()) {
match (content_length, &parts.method) {
(content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => {
if content_length.is_some() {
return Err(ContentLengthLimitRejection::ContentLengthNotAllowed(
ContentLengthNotAllowed,
));
} else if req
.headers()
} else if parts
.headers
.get(http::header::TRANSFER_ENCODING)
.map_or(false, |value| value.as_bytes() == b"chunked")
{
@ -76,12 +109,7 @@ where
_ => {}
}
let value = T::from_request(req)
.await
.map_err(ContentLengthLimitRejection::Inner)?;
Ok(Self(value))
}
Ok(())
}
impl<T, const N: u64> Deref for ContentLengthLimit<T, N> {

View file

@ -1,9 +1,12 @@
use super::{
rejection::{FailedToResolveHost, HostRejection},
FromRequest, RequestParts,
FromRequestParts,
};
use async_trait::async_trait;
use http::header::{HeaderMap, FORWARDED};
use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
@ -21,35 +24,34 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
pub struct Host(pub String);
#[async_trait]
impl<S, B> FromRequest<S, B> for Host
impl<S> FromRequestParts<S> for Host
where
B: Send,
S: Send + Sync,
{
type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(req.headers()) {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(&parts.headers) {
return Ok(Host(host.to_owned()));
}
if let Some(host) = req
.headers()
if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}
if let Some(host) = req
.headers()
if let Some(host) = parts
.headers
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}
if let Some(host) = req.uri().host() {
if let Some(host) = parts.uri.host() {
return Ok(Host(host.to_owned()));
}

View file

@ -1,5 +1,6 @@
use super::{rejection::*, FromRequest, RequestParts};
use super::{rejection::*, FromRequestParts};
use async_trait::async_trait;
use http::request::Parts;
use std::sync::Arc;
/// Access the path in the router that matches the request.
@ -64,16 +65,15 @@ impl MatchedPath {
}
#[async_trait]
impl<S, B> FromRequest<S, B> for MatchedPath
impl<S> FromRequestParts<S> for MatchedPath
where
B: Send,
S: Send + Sync,
{
type Rejection = MatchedPathRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let matched_path = req
.extensions()
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let matched_path = parts
.extensions
.get::<Self>()
.ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))?
.clone();

View file

@ -1,7 +1,6 @@
#![doc = include_str!("../docs/extract.md")]
use http::header;
use rejection::*;
use http::header::{self, HeaderMap};
pub mod connect_info;
pub mod path;
@ -17,7 +16,7 @@ mod request_parts;
mod state;
#[doc(inline)]
pub use axum_core::extract::{FromRef, FromRequest, RequestParts};
pub use axum_core::extract::{FromRef, FromRequest, FromRequestParts};
#[doc(inline)]
#[allow(deprecated)]
@ -75,16 +74,9 @@ pub use self::ws::WebSocketUpgrade;
#[doc(no_inline)]
pub use crate::TypedHeader;
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or_else(BodyAlreadyExtracted::default)
}
// this is duplicated in `axum-extra/src/extract/form.rs`
pub(super) fn has_content_type<S, B>(
req: &RequestParts<S, B>,
expected_content_type: &mime::Mime,
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
pub(super) fn has_content_type(headers: &HeaderMap, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type
} else {
return false;

View file

@ -2,12 +2,13 @@
//!
//! See [`Multipart`] for more details.
use super::{rejection::*, BodyStream, FromRequest, RequestParts};
use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
use std::{
fmt,
pin::Pin,
@ -58,10 +59,12 @@ where
{
type Rejection = MultipartRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let stream = match BodyStream::from_request(req, state).await {
Ok(stream) => stream,
Err(err) => match err {},
};
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}
@ -224,7 +227,6 @@ composite_rejection! {
///
/// Contains one variant for each way the [`Multipart`] extractor can fail.
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
}
}

View file

@ -4,12 +4,12 @@
mod de;
use crate::{
extract::{rejection::*, FromRequest, RequestParts},
extract::{rejection::*, FromRequestParts},
routing::url_params::UrlParams,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use http::StatusCode;
use http::{request::Parts, StatusCode};
use serde::de::DeserializeOwned;
use std::{
fmt,
@ -163,16 +163,15 @@ impl<T> DerefMut for Path<T> {
}
#[async_trait]
impl<T, S, B> FromRequest<S, B> for Path<T>
impl<T, S> FromRequestParts<S> for Path<T>
where
T: DeserializeOwned + Send,
B: Send,
S: Send + Sync,
{
type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<UrlParams>() {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError {
@ -413,8 +412,7 @@ impl std::error::Error for FailedToDeserializePathParams {}
mod tests {
use super::*;
use crate::{routing::get, test_helpers::*, Router};
use http::{Request, StatusCode};
use hyper::Body;
use http::StatusCode;
use std::collections::HashMap;
#[tokio::test]
@ -519,20 +517,6 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn when_extensions_are_missing() {
let app = Router::new().route("/:key", get(|_: Request<Body>, _: Path<String>| async {}));
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await,
"No paths parameters found for matched route. Are you also extracting `Request<_>`?"
);
}
#[tokio::test]
async fn str_reference_deserialize() {
struct Param(String);

View file

@ -1,5 +1,6 @@
use super::{rejection::*, FromRequest, RequestParts};
use super::{rejection::*, FromRequestParts};
use async_trait::async_trait;
use http::request::Parts;
use serde::de::DeserializeOwned;
use std::ops::Deref;
@ -49,16 +50,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T);
#[async_trait]
impl<T, S, B> FromRequest<S, B> for Query<T>
impl<T, S> FromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default();
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Query(value))
@ -76,15 +76,17 @@ impl<T> Deref for Query<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::extract::RequestParts;
use axum_core::extract::FromRequest;
use http::Request;
use serde::Deserialize;
use std::fmt::Debug;
async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
async fn check<T>(uri: impl AsRef<str>, value: T)
where
T: DeserializeOwned + PartialEq + Debug,
{
let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value);
assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
}
#[tokio::test]

View file

@ -1,5 +1,6 @@
use super::{FromRequest, RequestParts};
use super::FromRequestParts;
use async_trait::async_trait;
use http::request::Parts;
use std::convert::Infallible;
/// Extractor that extracts the raw query string, without parsing it.
@ -27,15 +28,14 @@ use std::convert::Infallible;
pub struct RawQuery(pub Option<String>);
#[async_trait]
impl<S, B> FromRequest<S, B> for RawQuery
impl<S> FromRequestParts<S> for RawQuery
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().map(|query| query.to_owned());
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().map(|query| query.to_owned());
Ok(Self(query))
}
}

View file

@ -73,7 +73,7 @@ define_rejection! {
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "No paths parameters found for matched route. Are you also extracting `Request<_>`?"]
#[body = "No paths parameters found for matched route"]
/// Rejection type used if axum's internal representation of path parameters
/// is missing. This is commonly caused by extracting `Request<_>`. `Path`
/// must be extracted first.

View file

@ -1,11 +1,11 @@
use super::{rejection::*, take_body, Extension, FromRequest, RequestParts};
use super::{Extension, FromRequest, FromRequestParts};
use crate::{
body::{Body, Bytes, HttpBody},
BoxError, Error,
};
use async_trait::async_trait;
use futures_util::stream::Stream;
use http::Uri;
use http::{request::Parts, Request, Uri};
use std::{
convert::Infallible,
fmt,
@ -86,17 +86,16 @@ pub struct OriginalUri(pub Uri);
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S, B> FromRequest<S, B> for OriginalUri
impl<S> FromRequestParts<S> for OriginalUri
where
B: Send,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request(req)
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request_parts(parts, state)
.await
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
.unwrap_or_else(|_| Extension(OriginalUri(parts.uri.clone())))
.0;
Ok(uri)
}
@ -148,10 +147,11 @@ where
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let body = req
.into_body()
.map_data(Into::into)
.map_err(|err| Error::new(err.into()));
let stream = BodyStream(SyncWrapper::new(Box::pin(body)));
@ -203,40 +203,17 @@ where
B: Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
Ok(Self(body))
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
Ok(Self(req.into_body()))
}
}
#[cfg(test)]
mod tests {
use crate::{
body::Body,
extract::Extension,
routing::{get, post},
test_helpers::*,
Router,
};
use http::{Method, Request, StatusCode};
#[tokio::test]
async fn multiple_request_extractors() {
async fn handler(_: Request<Body>, _: Request<Body>) {}
let app = Router::new().route("/", post(handler));
let client = TestClient::new(app);
let res = client.post("/").body("hi there").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
res.text().await,
"Cannot have two request body extractors for a single handler"
);
}
use crate::{extract::Extension, routing::get, test_helpers::*, Router};
use http::{Method, StatusCode};
#[tokio::test]
async fn extract_request_parts() {
@ -256,19 +233,4 @@ mod tests {
let res = client.get("/").header("x-foo", "123").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn extract_request_parts_doesnt_consume_the_body() {
#[derive(Clone)]
struct Ext;
async fn handler(_parts: http::request::Parts, body: String) {
assert_eq!(body, "foo");
}
let client = TestClient::new(Router::new().route("/", get(handler)));
let res = client.get("/").body("foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}

View file

@ -1,5 +1,6 @@
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequest, RequestParts};
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
@ -139,7 +140,8 @@ use std::{
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequest, RequestParts, FromRef};
/// use axum_core::extract::{FromRequestParts, FromRef};
/// use http::request::Parts;
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
@ -147,9 +149,8 @@ use std::{
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyLibraryExtractor
/// impl<S> FromRequestParts<S> for MyLibraryExtractor
/// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
@ -157,9 +158,9 @@ use std::{
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state());
/// let state = MyLibraryState::from_ref(state);
///
/// // ...
/// # todo!()
@ -171,23 +172,22 @@ use std::{
/// // ...
/// }
/// ```
///
/// Note that you don't need to use the `State` extractor since you can access the state directly
/// from [`RequestParts`].
#[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S);
#[async_trait]
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
B: Send,
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(req.state());
async fn from_request_parts(
_parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}

View file

@ -95,7 +95,7 @@
//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
use self::rejection::*;
use super::{FromRequest, RequestParts};
use super::FromRequestParts;
use crate::{
body::{self, Bytes},
response::Response,
@ -107,7 +107,8 @@ use futures_util::{
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderName, HeaderValue},
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
@ -275,41 +276,40 @@ impl WebSocketUpgrade {
}
#[async_trait]
impl<S, B> FromRequest<S, B> for WebSocketUpgrade
impl<S> FromRequestParts<S> for WebSocketUpgrade
where
B: Send,
S: Send + Sync,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() != Method::GET {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
if !header_contains(req, header::CONNECTION, "upgrade") {
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(req, header::UPGRADE, "websocket") {
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13") {
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = req
.headers_mut()
let sec_websocket_key = parts
.headers
.remove(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?;
let on_upgrade = req
.extensions_mut()
let on_upgrade = parts
.extensions
.remove::<OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = req.headers().get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
@ -321,16 +321,16 @@ where
}
}
fn header_eq<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) {
fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = headers.get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
fn header_contains<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) {
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = headers.get(&key) {
header
} else {
return false;

View file

@ -1,10 +1,10 @@
use crate::body::{Bytes, HttpBody};
use crate::extract::{has_content_type, rejection::*, FromRequest, RequestParts};
use crate::extract::{has_content_type, rejection::*, FromRequest};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use http::header::CONTENT_TYPE;
use http::{Method, StatusCode};
use http::{Method, Request, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::ops::Deref;
@ -59,25 +59,25 @@ pub struct Form<T>(pub T);
impl<T, S, B> FromRequest<S, B> for Form<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
Ok(Form(value))
} else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) {
if !has_content_type(req.headers(), &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType.into());
}
let bytes = Bytes::from_request(req).await?;
let bytes = Bytes::from_request(req, state).await?;
let value = serde_urlencoded::from_bytes(&bytes)
.map_err(FailedToDeserializeQueryString::__private_new)?;
@ -114,7 +114,6 @@ impl<T> Deref for Form<T> {
mod tests {
use super::*;
use crate::body::{Empty, Full};
use crate::extract::RequestParts;
use http::Request;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
@ -130,8 +129,7 @@ mod tests {
.uri(uri.as_ref())
.body(Empty::<Bytes>::new())
.unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
}
async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
@ -146,8 +144,7 @@ mod tests {
serde_urlencoded::to_string(&value).unwrap().into(),
))
.unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
}
#[tokio::test]
@ -216,9 +213,8 @@ mod tests {
.into(),
))
.unwrap();
let mut req = RequestParts::new(req);
assert!(matches!(
Form::<Pagination>::from_request(&mut req)
Form::<Pagination>::from_request(req, &())
.await
.unwrap_err(),
FormRejection::InvalidFormContentType(InvalidFormContentType)

View file

@ -88,7 +88,7 @@ where
use futures_util::future::FutureExt;
let handler = self.handler.clone();
let future = Handler::call(handler, Arc::clone(&self.state), req);
let future = Handler::call(handler, req, Arc::clone(&self.state));
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)

View file

@ -78,7 +78,7 @@ where
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, state, req);
let future = Handler::call(handler, req, state);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)

View file

@ -37,7 +37,7 @@
use crate::{
body::Body,
extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, RequestParts},
extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, FromRequestParts},
response::{IntoResponse, Response},
routing::IntoMakeService,
};
@ -95,12 +95,12 @@ pub use self::{into_service::IntoService, with_state::WithState};
/// {}
/// ```
#[doc = include_str!("../docs/debugging_handler_type_errors.md")]
pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
/// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request.
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future;
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler.
///
@ -162,7 +162,7 @@ pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
}
}
impl<F, Fut, Res, S, B> Handler<(), S, B> for F
impl<F, Fut, Res, S, B> Handler<((),), S, B> for F
where
F: FnOnce() -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
@ -171,37 +171,48 @@ where
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _state: Arc<S>, _req: Request<B>) -> Self::Future {
fn call(self, _req: Request<B>, _state: Arc<S>) -> Self::Future {
Box::pin(async move { self().await.into_response() })
}
}
macro_rules! impl_handler {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, S, B, Res, $($ty,)*> Handler<($($ty,)*), S, B> for F
(
[$($ty:ident),*], $last:ident
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, S, B, Res, M, $($ty,)* $last> Handler<(M, $($ty,)* $last,), S, B> for F
where
F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static,
F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
B: Send + 'static,
S: Send + Sync + 'static,
Res: IntoResponse,
$( $ty: FromRequest<S, B> + Send,)*
$( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<S, B, M> + Send,
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state_arc(state, req);
let (mut parts, body) = req.into_parts();
let state = &state;
$(
let $ty = match $ty::from_request(&mut req).await {
let $ty = match $ty::from_request_parts(&mut parts, state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*
let res = self($($ty,)*).await;
let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
let res = self($($ty,)* $last,).await;
res.into_response()
})
@ -210,7 +221,31 @@ macro_rules! impl_handler {
};
}
all_the_tuples!(impl_handler);
impl_handler!([], T1);
impl_handler!([T1], T2);
impl_handler!([T1, T2], T3);
impl_handler!([T1, T2, T3], T4);
impl_handler!([T1, T2, T3, T4], T5);
impl_handler!([T1, T2, T3, T4, T5], T6);
impl_handler!([T1, T2, T3, T4, T5, T6], T7);
impl_handler!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_handler!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_handler!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///
@ -259,7 +294,7 @@ where
{
type Future = future::LayeredFuture<B, L::Service>;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
use futures_util::future::{FutureExt, Map};
let svc = self.handler.with_state_arc(state);

View file

@ -1,14 +1,14 @@
use crate::{
body::{Bytes, HttpBody},
extract::{rejection::*, FromRequest, RequestParts},
extract::{rejection::*, FromRequest},
BoxError,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
use bytes::{BufMut, BytesMut};
use http::{
header::{self, HeaderValue},
StatusCode,
header::{self, HeaderMap, HeaderValue},
Request, StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use std::ops::{Deref, DerefMut};
@ -97,16 +97,16 @@ pub struct Json<T>(pub T);
impl<T, S, B> FromRequest<S, B> for Json<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = JsonRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if json_content_type(req) {
let bytes = Bytes::from_request(req).await?;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
if json_content_type(req.headers()) {
let bytes = Bytes::from_request(req, state).await?;
let value = match serde_json::from_slice(&bytes) {
Ok(value) => value,
@ -137,8 +137,8 @@ where
}
}
fn json_content_type<S, B>(req: &RequestParts<S, B>) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
fn json_content_type(headers: &HeaderMap) -> bool {
let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
content_type
} else {
return false;

View file

@ -93,8 +93,8 @@
//!
//! # Extractors
//!
//! An extractor is a type that implements [`FromRequest`]. Extractors is how
//! you pick apart the incoming request to get the parts your handler needs.
//! An extractor is a type that implements [`FromRequest`] or [`FromRequestParts`]. Extractors is
//! how you pick apart the incoming request to get the parts your handler needs.
//!
//! ```rust
//! use axum::extract::{Path, Query, Json};
@ -302,9 +302,10 @@
//!
//! # Building integrations for axum
//!
//! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations
//! should depend on the [`axum-core`] crate, instead of `axum` if possible. [`axum-core`] contains
//! core types and traits and is less likely to receive breaking changes.
//! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or
//! [`IntoResponse`] implementations should depend on the [`axum-core`] crate, instead of `axum` if
//! possible. [`axum-core`] contains core types and traits and is less likely to receive breaking
//! changes.
//!
//! # Required dependencies
//!
@ -376,6 +377,7 @@
//! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides
//! [`Uuid`]: https://docs.rs/uuid/latest/uuid/
//! [`FromRequest`]: crate::extract::FromRequest
//! [`FromRequestParts`]: crate::extract::FromRequestParts
//! [`HeaderMap`]: http::header::HeaderMap
//! [`Request`]: http::Request
//! [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs

View file

@ -1,5 +1,5 @@
use crate::{
extract::{FromRequest, RequestParts},
extract::FromRequestParts,
response::{IntoResponse, Response},
};
use futures_util::{future::BoxFuture, ready};
@ -33,28 +33,27 @@ use tower_service::Service;
///
/// ```rust
/// use axum::{
/// extract::{FromRequest, RequestParts},
/// extract::FromRequestParts,
/// middleware::from_extractor,
/// routing::{get, post},
/// Router,
/// http::{header, StatusCode, request::Parts},
/// };
/// use http::{header, StatusCode};
/// use async_trait::async_trait;
///
/// // An extractor that performs authorization.
/// struct RequireAuth;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for RequireAuth
/// impl<S> FromRequestParts<S> for RequireAuth
/// where
/// B: Send,
/// S: Send + Sync,
/// {
/// type Rejection = StatusCode;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
/// let auth_header = parts
/// .headers
/// .get(header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok());
///
@ -169,7 +168,7 @@ where
impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
where
E: FromRequest<(), B> + 'static,
E: FromRequestParts<()> + 'static,
B: Default + Send + 'static,
S: Service<Request<B>> + Clone,
S::Response: IntoResponse,
@ -185,8 +184,9 @@ where
fn call(&mut self, req: Request<B>) -> Self::Future {
let extract_future = Box::pin(async move {
let mut req = RequestParts::new(req);
let extracted = E::from_request(&mut req).await;
let (mut parts, body) = req.into_parts();
let extracted = E::from_request_parts(&mut parts, &()).await;
let req = Request::from_parts(parts, body);
(req, extracted)
});
@ -204,7 +204,7 @@ pin_project! {
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<B, S, E>
where
E: FromRequest<(), B>,
E: FromRequestParts<()>,
S: Service<Request<B>>,
{
#[pin]
@ -217,11 +217,11 @@ pin_project! {
#[project = StateProj]
enum State<B, S, E>
where
E: FromRequest<(), B>,
E: FromRequestParts<()>,
S: Service<Request<B>>,
{
Extracting {
future: BoxFuture<'static, (RequestParts<(), B>, Result<E, E::Rejection>)>,
future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
},
Call { #[pin] future: S::Future },
}
@ -229,7 +229,7 @@ pin_project! {
impl<B, S, E> Future for ResponseFuture<B, S, E>
where
E: FromRequest<(), B>,
E: FromRequestParts<()>,
S: Service<Request<B>>,
S::Response: IntoResponse,
B: Default,
@ -247,7 +247,6 @@ where
match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let req = req.try_into_request().unwrap_or_default();
let future = svc.call(req);
State::Call { future }
}
@ -273,23 +272,25 @@ where
mod tests {
use super::*;
use crate::{handler::Handler, routing::get, test_helpers::*, Router};
use http::{header, StatusCode};
use http::{header, request::Parts, StatusCode};
#[tokio::test]
async fn test_from_extractor() {
struct RequireAuth;
#[async_trait::async_trait]
impl<S, B> FromRequest<S, B> for RequireAuth
impl<S> FromRequestParts<S> for RequireAuth
where
B: Send,
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
if let Some(auth) = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
{

View file

@ -1,5 +1,5 @@
use crate::response::{IntoResponse, Response};
use axum_core::extract::{FromRequest, RequestParts};
use axum_core::extract::{FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use std::{
@ -249,12 +249,15 @@ where
}
macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, Out, S, B, $($ty,)*> Service<Request<B>> for FromFn<F, S, ($($ty,)*)>
(
[$($ty:ident),*], $last:ident
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
where
F: FnMut($($ty),*, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequest<(), B> + Send, )*
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<()> + Send, )*
$last: FromRequest<(), B> + Send,
Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static,
S: Service<Request<B>, Error = Infallible>
@ -280,21 +283,29 @@ macro_rules! impl_service {
let mut f = self.f.clone();
let future = Box::pin(async move {
let mut parts = RequestParts::new(req);
let (mut parts, body) = req.into_parts();
$(
let $ty = match $ty::from_request(&mut parts).await {
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*
let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, &()).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
let inner = ServiceBuilder::new()
.boxed_clone()
.map_response(IntoResponse::into_response)
.service(ready_inner);
let next = Next { inner };
f($($ty),*, next).await.into_response()
f($($ty,)* $last, next).await.into_response()
});
ResponseFuture {
@ -305,7 +316,31 @@ macro_rules! impl_service {
};
}
all_the_tuples!(impl_service);
impl_service!([], T1);
impl_service!([T1], T2);
impl_service!([T1, T2], T3);
impl_service!([T1, T2, T3], T4);
impl_service!([T1, T2, T3, T4], T5);
impl_service!([T1, T2, T3, T4, T5], T6);
impl_service!([T1, T2, T3, T4, T5, T6], T7);
impl_service!([T1, T2, T3, T4, T5, T6, T7], T8);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8], T9);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12);
impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13],
T14
);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14],
T15
);
impl_service!(
[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15],
T16
);
impl<F, S, T> fmt::Debug for FromFn<F, S, T>
where

View file

@ -1,7 +1,8 @@
use crate::extract::{FromRequest, RequestParts};
use crate::extract::FromRequestParts;
use async_trait::async_trait;
use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts};
use headers::HeaderMapExt;
use http::request::Parts;
use std::{convert::Infallible, ops::Deref};
/// Extractor and response that works with typed header values from [`headers`].
@ -52,16 +53,15 @@ use std::{convert::Infallible, ops::Deref};
pub struct TypedHeader<T>(pub T);
#[async_trait]
impl<T, S, B> FromRequest<S, B> for TypedHeader<T>
impl<T, S> FromRequestParts<S> for TypedHeader<T>
where
T: headers::Header,
B: Send,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match req.headers().typed_try_get::<T>() {
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match parts.headers.typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection {
name: T::name(),

View file

@ -7,7 +7,7 @@
use axum::{
async_trait,
body::{self, BoxBody, Bytes, Full},
extract::{FromRequest, RequestParts},
extract::FromRequest,
http::{Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
@ -72,31 +72,28 @@ fn do_thing_with_request_body(bytes: Bytes) {
tracing::debug!(body = ?bytes);
}
async fn handler(_: PrintRequestBody, body: Bytes) {
async fn handler(BufferRequestBody(body): BufferRequestBody) {
tracing::debug!(?body, "handler received body");
}
// extractor that shows how to consume the request body upfront
struct PrintRequestBody;
struct BufferRequestBody(Bytes);
// we must implement `FromRequest` (and not `FromRequestParts`) to consume the body
#[async_trait]
impl<S> FromRequest<S, BoxBody> for PrintRequestBody
impl<S> FromRequest<S, BoxBody> for BufferRequestBody
where
S: Clone + Send + Sync,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: &mut RequestParts<S, BoxBody>) -> Result<Self, Self::Rejection> {
let state = req.state().clone();
let request = Request::from_request(req)
async fn from_request(req: Request<BoxBody>, state: &S) -> Result<Self, Self::Rejection> {
let body = Bytes::from_request(req, state)
.await
.map_err(|err| err.into_response())?;
let request = buffer_request_body(request).await?;
do_thing_with_request_body(body.clone());
*req = RequestParts::with_state(state, request);
Ok(Self)
Ok(Self(body))
}
}

View file

@ -4,15 +4,13 @@
//! and `async/await`. This means that you can create more powerful rejections
//! - Boilerplate: Requires creating a new extractor for every custom rejection
//! - Complexity: Manually implementing `FromRequest` results on more complex code
use axum::extract::MatchedPath;
use axum::{
async_trait,
extract::{rejection::JsonRejection, FromRequest, RequestParts},
extract::{rejection::JsonRejection, FromRequest, FromRequestParts, MatchedPath},
http::Request,
http::StatusCode,
response::IntoResponse,
BoxError,
};
use serde::de::DeserializeOwned;
use serde_json::{json, Value};
pub async fn handler(Json(value): Json<Value>) -> impl IntoResponse {
@ -25,31 +23,33 @@ pub struct Json<T>(pub T);
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Json<T>
where
axum::Json<T>: FromRequest<S, B, Rejection = JsonRejection>,
S: Send + Sync,
// these trait bounds are copied from `impl FromRequest for axum::Json`
// `T: Send` is required to send this future across an await
T: DeserializeOwned + Send,
B: axum::body::HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
B: Send + 'static,
{
type Rejection = (StatusCode, axum::Json<Value>);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match axum::Json::<T>::from_request(req).await {
Ok(value) => Ok(Self(value.0)),
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
let path = req
.extract::<MatchedPath>()
.await
.map(|x| x.as_str().to_owned())
.ok();
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let (mut parts, body) = req.into_parts();
// We can use other extractors to provide better rejection
// messages. For example, here we are using
// `axum::extract::MatchedPath` to provide a better error
// message
//
// Have to run that first since `Json::from_request` consumes
// the request
let path = MatchedPath::from_request_parts(&mut parts, state)
.await
.map(|path| path.as_str().to_owned())
.ok();
let req = Request::from_parts(parts, body);
match axum::Json::<T>::from_request(req, state).await {
Ok(value) => Ok(Self(value.0)),
// convert the error from `axum::Json` into whatever we want
Err(rejection) => {
let payload = json!({
"message": rejection.to_string(),
"origin": "custom_extractor",

View file

@ -47,7 +47,7 @@ impl From<JsonRejection> for ApiError {
}
}
// We implement `IntoResponse` so ApiError can be used as a response
// We implement `IntoResponse` so `ApiError` can be used as a response
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
let payload = json!({

View file

@ -6,8 +6,8 @@
use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts},
http::StatusCode,
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts},
http::{request::Parts, StatusCode},
response::IntoResponse,
routing::get,
Router,
@ -52,17 +52,16 @@ struct Params {
struct Path<T>(T);
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Path<T>
impl<S, T> FromRequestParts<S> for Path<T>
where
// these trait bounds are copied from `impl FromRequest for axum::extract::path::Path`
T: DeserializeOwned + Send,
B: Send,
S: Send + Sync,
{
type Rejection = (StatusCode, axum::Json<PathError>);
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match axum::extract::Path::<T>::from_request(req).await {
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match axum::extract::Path::<T>::from_request_parts(parts, state).await {
Ok(value) => Ok(Self(value.0)),
Err(rejection) => {
let (status, body) = match rejection {

View file

@ -65,8 +65,8 @@ async fn users_show(
/// Handler for `POST /users`.
async fn users_create(
Json(params): Json<CreateUser>,
State(user_repo): State<DynUserRepo>,
Json(params): Json<CreateUser>,
) -> Result<Json<User>, AppError> {
let user = user_repo.create(params).await?;

View file

@ -8,9 +8,9 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts, TypedHeader},
extract::{FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
http::StatusCode,
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
@ -122,17 +122,16 @@ impl AuthBody {
}
#[async_trait]
impl<S, B> FromRequest<S, B> for Claims
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
B: Send,
{
type Rejection = AuthError;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request(req)
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| AuthError::InvalidToken)?;
// Decode the user data

View file

@ -96,8 +96,8 @@ async fn kv_get(
async fn kv_set(
Path(key): Path<String>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
State(state): State<SharedState>,
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
) {
state.write().unwrap().db.insert(key, bytes);
}

View file

@ -12,15 +12,14 @@ use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
extract::{
rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State,
TypedHeader,
rejection::TypedHeaderRejectionReason, FromRef, FromRequestParts, Query, State, TypedHeader,
},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
Router,
};
use http::header;
use http::{header, request::Parts};
use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
@ -139,7 +138,7 @@ async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
.url();
// Redirect to Discord's oauth service
Redirect::to(&auth_url.to_string())
Redirect::to(auth_url.as_ref())
}
// Valid user session required. If there is none, redirect to the auth page
@ -224,17 +223,18 @@ impl IntoResponse for AuthRedirect {
}
#[async_trait]
impl<B> FromRequest<AppState, B> for User
impl<S> FromRequestParts<S> for User
where
B: Send,
MemoryStore: FromRef<S>,
S: Send + Sync,
{
// If anything goes wrong or no session is found, redirect to the auth page
type Rejection = AuthRedirect;
async fn from_request(req: &mut RequestParts<AppState, B>) -> Result<Self, Self::Rejection> {
let store = req.state().clone().store;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = MemoryStore::from_ref(state);
let cookies = TypedHeader::<headers::Cookie>::from_request(req)
let cookies = TypedHeader::<headers::Cookie>::from_request_parts(parts, state)
.await
.map_err(|e| match *e.name() {
header::COOKIE => match e.reason() {

View file

@ -7,11 +7,12 @@
use async_session::{MemoryStore, Session, SessionStore as _};
use axum::{
async_trait,
extract::{FromRequest, RequestParts, TypedHeader},
extract::{FromRef, FromRequestParts, TypedHeader},
headers::Cookie,
http::{
self,
header::{HeaderMap, HeaderValue},
request::Parts,
StatusCode,
},
response::IntoResponse,
@ -80,16 +81,19 @@ enum UserIdFromSession {
}
#[async_trait]
impl<B> FromRequest<MemoryStore, B> for UserIdFromSession
impl<S> FromRequestParts<S> for UserIdFromSession
where
B: Send,
MemoryStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<MemoryStore, B>) -> Result<Self, Self::Rejection> {
let store = req.state().clone();
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let store = MemoryStore::from_ref(state);
let cookie = req.extract::<Option<TypedHeader<Cookie>>>().await.unwrap();
let cookie = Option::<TypedHeader<Cookie>>::from_request_parts(parts, state)
.await
.unwrap();
let session_cookie = cookie
.as_ref()

View file

@ -15,8 +15,8 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts, State},
http::StatusCode,
extract::{FromRef, FromRequestParts, State},
http::{request::Parts, StatusCode},
routing::get,
Router,
};
@ -75,14 +75,15 @@ async fn using_connection_pool_extractor(
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
#[async_trait]
impl<B> FromRequest<PgPool, B> for DatabaseConnection
impl<S> FromRequestParts<S> for DatabaseConnection
where
B: Send,
PgPool: FromRef<S>,
S: Send + Sync,
{
type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<PgPool, B>) -> Result<Self, Self::Rejection> {
let pool = req.state().clone();
async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let pool = PgPool::from_ref(state);
let conn = pool.acquire().await.map_err(internal_error)?;

Some files were not shown because too many files have changed in this diff Show more