Add type safe state extractor (#1155)

* begin threading the state through

* Pass state to extractors

* make state extractor work

* make sure nesting with different states work

* impl Service for MethodRouter<()>

* Fix some of axum-macro's tests

* Implement more traits for `State`

* Update examples to use `State`

* consistent naming of request body param

* swap type params

* Default the state param to ()

* fix docs references

* Docs and handler state refactoring

* docs clean ups

* more consistent naming

* when does MethodRouter implement Service?

* add missing docs

* use `Router`'s default state type param

* changelog

* don't use default type param for FromRequest and RequestParts

probably safer for library authors so you don't accidentally forget

* fix examples

* minor docs tweaks

* clarify how to convert handlers into services

* group methods in one impl block

* make sure merged `MethodRouter`s can access state

* fix docs link

* test merge with same state type

* Document how to access state from middleware

* Port cookie extractors to use state to extract keys (#1250)

* Updates ECOSYSTEM with a new sample project (#1252)

* Avoid unhelpful compiler suggestion (#1251)

* fix docs typo

* document how library authors should access state

* Add `RequestParts::with_state`

* fix example

* apply suggestions from review

* add relevant changes to axum-extra and axum-core changelogs

* Add `route_service_with_tsr`

* fix trybuild expectations

* make sure `SpaRouter` works with routers that have state

* Change order of type params on FromRequest and RequestParts

* reverse order of `RequestParts::with_state` args to match type params

* Add `FromRef` trait (#1268)

* Add `FromRef` trait

* Remove unnecessary type params

* format

* fix docs link

* format examples

* Avoid unnecessary `MethodRouter`

* apply suggestions from review

Co-authored-by: Dani Pardo <dani.pardo@inmensys.com>
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-08-17 17:13:31 +02:00 committed by GitHub
parent 90dbd52ee4
commit 423308de3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
132 changed files with 2404 additions and 1126 deletions

View file

@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- None. - **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which
represents the state ([#1155])
[#1155]: https://github.com/tokio-rs/axum/pull/1155
# 0.2.6 (18. June, 2022) # 0.2.6 (18. June, 2022)

View file

@ -0,0 +1,23 @@
/// Used to do reference-to-value conversions thus not consuming the input value.
///
/// This is mainly used with [`State`] to extract "substates" from a reference to main application
/// state.
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
/// Converts to this type from a reference to the input type.
fn from_ref(input: &T) -> Self;
}
impl<T> FromRef<T> for T
where
T: Clone,
{
fn from_ref(input: &T) -> Self {
input.clone()
}
}

View file

@ -12,9 +12,12 @@ use std::convert::Infallible;
pub mod rejection; pub mod rejection;
mod from_ref;
mod request_parts; mod request_parts;
mod tuple; mod tuple;
pub use self::from_ref::FromRef;
/// Types that can be created from requests. /// Types that can be created from requests.
/// ///
/// See [`axum::extract`] for more details. /// See [`axum::extract`] for more details.
@ -42,13 +45,15 @@ mod tuple;
/// struct MyExtractor; /// struct MyExtractor;
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B> FromRequest<B> for MyExtractor /// impl<S, B> FromRequest<S, B> for MyExtractor
/// where /// where
/// B: Send, // required by `async_trait` /// // these bounds are required by `async_trait`
/// B: Send,
/// S: Send,
/// { /// {
/// type Rejection = http::StatusCode; /// type Rejection = http::StatusCode;
/// ///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { /// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ... /// // ...
/// # unimplemented!() /// # unimplemented!()
/// } /// }
@ -60,20 +65,21 @@ mod tuple;
/// [`http::Request<B>`]: http::Request /// [`http::Request<B>`]: http::Request
/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
#[async_trait] #[async_trait]
pub trait FromRequest<B>: Sized { pub trait FromRequest<S, B>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is /// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response. /// a kind of error that can be converted into a response.
type Rejection: IntoResponse; type Rejection: IntoResponse;
/// Perform the extraction. /// Perform the extraction.
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection>; async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection>;
} }
/// The type used with [`FromRequest`] to extract data from requests. /// The type used with [`FromRequest`] to extract data from requests.
/// ///
/// Has several convenience methods for getting owned parts of the request. /// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)] #[derive(Debug)]
pub struct RequestParts<B> { pub struct RequestParts<S, B> {
state: S,
method: Method, method: Method,
uri: Uri, uri: Uri,
version: Version, version: Version,
@ -82,8 +88,8 @@ pub struct RequestParts<B> {
body: Option<B>, body: Option<B>,
} }
impl<B> RequestParts<B> { impl<B> RequestParts<(), B> {
/// Create a new `RequestParts`. /// Create a new `RequestParts` without any state.
/// ///
/// You generally shouldn't need to construct this type yourself, unless /// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a /// using extractors outside of axum for example to implement a
@ -91,6 +97,19 @@ impl<B> RequestParts<B> {
/// ///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn new(req: Request<B>) -> Self { pub fn new(req: Request<B>) -> Self {
Self::with_state((), req)
}
}
impl<S, B> RequestParts<S, B> {
/// Create a new `RequestParts` with the given state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
let ( let (
http::request::Parts { http::request::Parts {
method, method,
@ -104,6 +123,7 @@ impl<B> RequestParts<B> {
) = req.into_parts(); ) = req.into_parts();
RequestParts { RequestParts {
state,
method, method,
uri, uri,
version, version,
@ -130,10 +150,14 @@ impl<B> RequestParts<B> {
/// use http::{Method, Uri}; /// use http::{Method, Uri};
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B: Send> FromRequest<B> for MyExtractor { /// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = Infallible; /// type Rejection = Infallible;
/// ///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Infallible> { /// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Infallible> {
/// let method = req.extract::<Method>().await?; /// let method = req.extract::<Method>().await?;
/// let path = req.extract::<Uri>().await?.path().to_owned(); /// let path = req.extract::<Uri>().await?.path().to_owned();
/// ///
@ -141,7 +165,10 @@ impl<B> RequestParts<B> {
/// } /// }
/// } /// }
/// ``` /// ```
pub async fn extract<E: FromRequest<B>>(&mut self) -> Result<E, E::Rejection> { pub async fn extract<E>(&mut self) -> Result<E, E::Rejection>
where
E: FromRequest<S, B>,
{
E::from_request(self).await E::from_request(self).await
} }
@ -153,6 +180,7 @@ impl<B> RequestParts<B> {
/// [`take_body`]: RequestParts::take_body /// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> { pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> {
let Self { let Self {
state: _,
method, method,
uri, uri,
version, version,
@ -245,30 +273,37 @@ impl<B> RequestParts<B> {
pub fn take_body(&mut self) -> Option<B> { pub fn take_body(&mut self) -> Option<B> {
self.body.take() self.body.take()
} }
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
} }
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Option<T> impl<S, T, B> FromRequest<S, B> for Option<T>
where where
T: FromRequest<B>, T: FromRequest<S, B>,
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Option<T>, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok()) Ok(T::from_request(req).await.ok())
} }
} }
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Result<T, T::Rejection> impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where where
T: FromRequest<B>, T: FromRequest<S, B>,
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await) Ok(T::from_request(req).await)
} }
} }

View file

@ -6,16 +6,18 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible; use std::convert::Infallible;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Request<B> impl<S, B> FromRequest<S, B> for Request<B>
where where
B: Send, B: Send,
S: Clone + Send,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let req = std::mem::replace( let req = std::mem::replace(
req, req,
RequestParts { RequestParts {
state: req.state().clone(),
method: req.method.clone(), method: req.method.clone(),
version: req.version, version: req.version,
uri: req.uri.clone(), uri: req.uri.clone(),
@ -30,37 +32,40 @@ where
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Method impl<S, B> FromRequest<S, B> for Method
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.method().clone()) Ok(req.method().clone())
} }
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Uri impl<S, B> FromRequest<S, B> for Uri
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone()) Ok(req.uri().clone())
} }
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Version impl<S, B> FromRequest<S, B> for Version
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.version()) Ok(req.version())
} }
} }
@ -71,27 +76,29 @@ where
/// ///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for HeaderMap impl<S, B> FromRequest<S, B> for HeaderMap
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.headers().clone()) Ok(req.headers().clone())
} }
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Bytes impl<S, B> FromRequest<S, B> for Bytes
where where
B: http_body::Body + Send, B: http_body::Body + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = BytesRejection; type Rejection = BytesRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; let body = take_body(req)?;
let bytes = crate::body::to_bytes(body) let bytes = crate::body::to_bytes(body)
@ -103,15 +110,16 @@ where
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for String impl<S, B> FromRequest<S, B> for String
where where
B: http_body::Body + Send, B: http_body::Body + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = StringRejection; type Rejection = StringRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; let body = take_body(req)?;
let bytes = crate::body::to_bytes(body) let bytes = crate::body::to_bytes(body)
@ -126,13 +134,14 @@ where
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for http::request::Parts impl<S, B> FromRequest<S, B> for http::request::Parts
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await); let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await);
@ -159,6 +168,6 @@ fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
} }
} }
pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> { pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted) req.take_body().ok_or(BodyAlreadyExtracted)
} }

View file

@ -4,13 +4,14 @@ use async_trait::async_trait;
use std::convert::Infallible; use std::convert::Infallible;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for () impl<S, B> FromRequest<S, B> for ()
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> { async fn from_request(_: &mut RequestParts<S, B>) -> Result<(), Self::Rejection> {
Ok(()) Ok(())
} }
} }
@ -21,14 +22,15 @@ macro_rules! impl_from_request {
( $($ty:ident),* $(,)? ) => { ( $($ty:ident),* $(,)? ) => {
#[async_trait] #[async_trait]
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<B, $($ty,)*> FromRequest<B> for ($($ty,)*) impl<S, B, $($ty,)*> FromRequest<S, B> for ($($ty,)*)
where where
$( $ty: FromRequest<B> + Send, )* $( $ty: FromRequest<S, B> + Send, )*
B: Send, B: Send,
S: Send,
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )*
Ok(($($ty,)*)) Ok(($($ty,)*))
} }

View file

@ -17,15 +17,21 @@ and this project adheres to [Semantic Versioning].
literal `Response` literal `Response`
- **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170]) - **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170])
- **change:** axum-extra's MSRV is now 1.60 ([#1239]) - **change:** axum-extra's MSRV is now 1.60 ([#1239])
- **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extracts the keys
from the router's state, rather than extensions
- **added:** Add Protocol Buffer extractor and response ([#1239]) - **added:** Add Protocol Buffer extractor and response ([#1239])
- **added:** Add `Either*` types for combining extractors and responses into a - **added:** Add `Either*` types for combining extractors and responses into a
single type ([#1263]) single type ([#1263])
- **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262]) - **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262])
- **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and
`SignedCookieJar` so they're easier to use in custom middleware `SignedCookieJar` so they're easier to use in custom middleware
- **breaking:** `Resource` has a new `S` type param which represents the state ([#1155])
- **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155])
- **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155])
[#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1086]: https://github.com/tokio-rs/axum/pull/1086
[#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1119]: https://github.com/tokio-rs/axum/pull/1119
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1170]: https://github.com/tokio-rs/axum/pull/1170 [#1170]: https://github.com/tokio-rs/axum/pull/1170
[#1214]: https://github.com/tokio-rs/axum/pull/1214 [#1214]: https://github.com/tokio-rs/axum/pull/1214
[#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1239]: https://github.com/tokio-rs/axum/pull/1239

View file

@ -190,15 +190,16 @@ macro_rules! impl_traits_for_either {
$last:ident $(,)? $last:ident $(,)?
) => { ) => {
#[async_trait] #[async_trait]
impl<B, $($ident),*, $last> FromRequest<B> for $either<$($ident),*, $last> impl<S, B, $($ident),*, $last> FromRequest<S, B> for $either<$($ident),*, $last>
where where
$($ident: FromRequest<B>),*, $($ident: FromRequest<S, B>),*,
$last: FromRequest<B>, $last: FromRequest<S, B>,
B: Send, B: Send,
S: Send,
{ {
type Rejection = $last::Rejection; type Rejection = $last::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$( $(
if let Ok(value) = req.extract().await { if let Ok(value) = req.extract().await {
return Ok(Self::$ident(value)); return Ok(Self::$ident(value));

View file

@ -30,13 +30,14 @@ use std::ops::{Deref, DerefMut};
/// struct Session { /* ... */ } /// struct Session { /* ... */ }
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B> FromRequest<B> for Session /// impl<S, B> FromRequest<S, B> for Session
/// where /// where
/// B: Send, /// B: Send,
/// S: Send,
/// { /// {
/// type Rejection = (StatusCode, String); /// type Rejection = (StatusCode, String);
/// ///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { /// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // load session... /// // load session...
/// # unimplemented!() /// # unimplemented!()
/// } /// }
@ -45,13 +46,14 @@ use std::ops::{Deref, DerefMut};
/// struct CurrentUser { /* ... */ } /// struct CurrentUser { /* ... */ }
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B> FromRequest<B> for CurrentUser /// impl<S, B> FromRequest<S, B> for CurrentUser
/// where /// where
/// B: Send, /// B: Send,
/// S: Send,
/// { /// {
/// type Rejection = Response; /// type Rejection = Response;
/// ///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { /// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // loading a `CurrentUser` requires first loading the `Session` /// // loading a `CurrentUser` requires first loading the `Session`
/// // /// //
/// // by using `Cached<Session>` we avoid extracting the session more than /// // by using `Cached<Session>` we avoid extracting the session more than
@ -88,14 +90,15 @@ pub struct Cached<T>(pub T);
struct CachedEntry<T>(T); struct CachedEntry<T>(T);
#[async_trait] #[async_trait]
impl<B, T> FromRequest<B> for Cached<T> impl<S, B, T> FromRequest<S, B> for Cached<T>
where where
B: Send, B: Send,
T: FromRequest<B> + Clone + Send + Sync + 'static, S: Send,
T: FromRequest<S, B> + Clone + Send + Sync + 'static,
{ {
type Rejection = T::Rejection; type Rejection = T::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await { match Extension::<CachedEntry<T>>::from_request(req).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => { Err(_) => {
@ -139,13 +142,14 @@ mod tests {
struct Extractor(Instant); struct Extractor(Instant);
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Extractor impl<S, B> FromRequest<S, B> for Extractor
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst); COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now())) Ok(Self(Instant::now()))
} }

View file

@ -80,7 +80,7 @@ pub use cookie::Key;
/// let app = Router::new() /// let app = Router::new()
/// .route("/sessions", post(create_session)) /// .route("/sessions", post(create_session))
/// .route("/me", get(me)); /// .route("/me", get(me));
/// # let app: Router<axum::body::Body> = app; /// # let app: Router = app;
/// ``` /// ```
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct CookieJar { pub struct CookieJar {
@ -88,13 +88,14 @@ pub struct CookieJar {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for CookieJar impl<S, B> FromRequest<S, B> for CookieJar
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(Self::from_headers(req.headers())) Ok(Self::from_headers(req.headers()))
} }
} }
@ -226,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::{body::Body, http::Request, routing::get, Extension, Router}; use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use tower::ServiceExt; use tower::ServiceExt;
macro_rules! cookie_test { macro_rules! cookie_test {
@ -245,12 +246,15 @@ mod tests {
jar.remove(Cookie::named("key")) jar.remove(Cookie::named("key"))
} }
let app = Router::<Body>::new() let state = AppState {
key: Key::generate(),
custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state)
.route("/set", get(set_cookie)) .route("/set", get(set_cookie))
.route("/get", get(get_cookie)) .route("/get", get(get_cookie))
.route("/remove", get(remove_cookie)) .route("/remove", get(remove_cookie));
.layer(Extension(Key::generate()))
.layer(Extension(CustomKey(Key::generate())));
let res = app let res = app
.clone() .clone()
@ -298,6 +302,24 @@ mod tests {
cookie_test!(private_cookies, PrivateCookieJar); cookie_test!(private_cookies, PrivateCookieJar);
cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>); cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);
#[derive(Clone)]
struct AppState {
key: Key,
custom_key: CustomKey,
}
impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Key {
state.key.clone()
}
}
impl FromRef<AppState> for CustomKey {
fn from_ref(state: &AppState) -> CustomKey {
state.custom_key.clone()
}
}
#[derive(Clone)] #[derive(Clone)]
struct CustomKey(Key); struct CustomKey(Key);
@ -313,9 +335,12 @@ mod tests {
format!("{:?}", jar.get("key")) format!("{:?}", jar.get("key"))
} }
let app = Router::<Body>::new() let state = AppState {
.route("/get", get(get_cookie)) key: Key::generate(),
.layer(Extension(Key::generate())); custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state).route("/get", get(get_cookie));
let res = app let res = app
.clone() .clone()

View file

@ -1,9 +1,8 @@
use super::{cookies_from_request, set_cookies, Cookie, Key}; use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
Extension,
}; };
use cookie::PrivateJar; use cookie::PrivateJar;
use http::HeaderMap; use http::HeaderMap;
@ -23,9 +22,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// ```rust /// ```rust
/// use axum::{ /// use axum::{
/// Router, /// Router,
/// Extension,
/// routing::{post, get}, /// routing::{post, get},
/// extract::TypedHeader, /// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect}, /// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer}, /// headers::authorization::{Authorization, Bearer},
/// http::StatusCode, /// http::StatusCode,
@ -45,22 +43,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// } /// }
/// } /// }
/// ///
/// // Generate a secure key /// // our application state
/// // /// #[derive(Clone)]
/// // You probably don't wanna generate a new one each time the app starts though /// struct AppState {
/// let key = Key::generate(); /// // that holds the key used to sign cookies
/// key: Key,
/// }
/// ///
/// let app = Router::new() /// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
/// let state = AppState {
/// // Generate a secure key
/// //
/// // You probably don't wanna generate a new one each time the app starts though
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// .route("/set", post(set_secret)) /// .route("/set", post(set_secret))
/// .route("/get", get(get_secret)) /// .route("/get", get(get_secret));
/// // add extension with the key so `PrivateCookieJar` can access it /// # let app: Router<_> = app;
/// .layer(Extension(key));
/// # let app: Router<axum::body::Body> = app;
/// ``` /// ```
pub struct PrivateCookieJar<K = Key> { pub struct PrivateCookieJar<K = Key> {
jar: cookie::CookieJar, jar: cookie::CookieJar,
key: Key, key: Key,
// The key used to extract the key extension. Allows users to use multiple keys for different // The key used to extract the key. Allows users to use multiple keys for different
// jars. Maybe a library wants its own key. // jars. Maybe a library wants its own key.
_marker: PhantomData<K>, _marker: PhantomData<K>,
} }
@ -75,15 +87,17 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
} }
#[async_trait] #[async_trait]
impl<B, K> FromRequest<B> for PrivateCookieJar<K> impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where where
B: Send, B: Send,
K: Into<Key> + Clone + Send + Sync + 'static, S: Send,
K: FromRef<S> + Into<Key>,
{ {
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let key = req.extract::<Extension<K>>().await?.0.into(); let k = K::from_ref(req.state());
let key = k.into();
let PrivateCookieJar { let PrivateCookieJar {
jar, jar,
key, key,

View file

@ -1,9 +1,8 @@
use super::{cookies_from_request, set_cookies}; use super::{cookies_from_request, set_cookies};
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequest, RequestParts}, extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
Extension,
}; };
use cookie::SignedJar; use cookie::SignedJar;
use cookie::{Cookie, Key}; use cookie::{Cookie, Key};
@ -24,9 +23,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// ```rust /// ```rust
/// use axum::{ /// use axum::{
/// Router, /// Router,
/// Extension,
/// routing::{post, get}, /// routing::{post, get},
/// extract::TypedHeader, /// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect}, /// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer}, /// headers::authorization::{Authorization, Bearer},
/// http::StatusCode, /// http::StatusCode,
@ -63,22 +61,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// # todo!() /// # todo!()
/// } /// }
/// ///
/// // Generate a secure key /// // our application state
/// // /// #[derive(Clone)]
/// // You probably don't wanna generate a new one each time the app starts though /// struct AppState {
/// let key = Key::generate(); /// // that holds the key used to sign cookies
/// key: Key,
/// }
/// ///
/// let app = Router::new() /// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
/// let state = AppState {
/// // Generate a secure key
/// //
/// // You probably don't wanna generate a new one each time the app starts though
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// .route("/sessions", post(create_session)) /// .route("/sessions", post(create_session))
/// .route("/me", get(me)) /// .route("/me", get(me));
/// // add extension with the key so `SignedCookieJar` can access it /// # let app: Router<_> = app;
/// .layer(Extension(key));
/// # let app: Router<axum::body::Body> = app;
/// ``` /// ```
pub struct SignedCookieJar<K = Key> { pub struct SignedCookieJar<K = Key> {
jar: cookie::CookieJar, jar: cookie::CookieJar,
key: Key, key: Key,
// The key used to extract the key extension. Allows users to use multiple keys for different // The key used to extract the key. Allows users to use multiple keys for different
// jars. Maybe a library wants its own key. // jars. Maybe a library wants its own key.
_marker: PhantomData<K>, _marker: PhantomData<K>,
} }
@ -93,15 +105,17 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
} }
#[async_trait] #[async_trait]
impl<B, K> FromRequest<B> for SignedCookieJar<K> impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where where
B: Send, B: Send,
K: Into<Key> + Clone + Send + Sync + 'static, S: Send,
K: FromRef<S> + Into<Key>,
{ {
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let key = req.extract::<Extension<K>>().await?.0.into(); let k = K::from_ref(req.state());
let key = k.into();
let SignedCookieJar { let SignedCookieJar {
jar, jar,
key, key,

View file

@ -55,16 +55,17 @@ impl<T> Deref for Form<T> {
} }
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Form<T> impl<T, S, B> FromRequest<S, B> for Form<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = FormRejection; type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET { if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query) let value = serde_html_form::from_str(query)
@ -85,7 +86,7 @@ where
} }
// this is duplicated in `axum/src/extract/mod.rs` // this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type<B>(req: &RequestParts<B>, expected_content_type: &mime::Mime) -> bool { fn has_content_type<S, B>(req: &RequestParts<S, B>, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type content_type
} else { } else {

View file

@ -58,14 +58,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T); pub struct Query<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Query<T> impl<T, S, B> FromRequest<S, B> for Query<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: Send, B: Send,
S: Send,
{ {
type Rejection = QueryRejection; type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query) let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;

View file

@ -107,15 +107,16 @@ impl<E, R> DerefMut for WithRejection<E, R> {
} }
#[async_trait] #[async_trait]
impl<B, E, R> FromRequest<B> for WithRejection<E, R> impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R>
where where
B: Send, B: Send,
E: FromRequest<B>, S: Send,
E: FromRequest<S, B>,
R: From<E::Rejection> + IntoResponse, R: From<E::Rejection> + IntoResponse,
{ {
type Rejection = R; type Rejection = R;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let extractor = req.extract::<E>().await?; let extractor = req.extract::<E>().await?;
Ok(WithRejection(extractor, PhantomData)) Ok(WithRejection(extractor, PhantomData))
} }
@ -134,10 +135,14 @@ mod tests {
struct TestRejection; struct TestRejection;
#[async_trait] #[async_trait]
impl<B: Send> FromRequest<B> for TestExtractor { impl<S, B> FromRequest<S, B> for TestExtractor
where
B: Send,
S: Send,
{
type Rejection = (); type Rejection = ();
async fn from_request(_: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Err(()) Err(())
} }
} }

View file

@ -19,15 +19,15 @@ pub use self::or::Or;
/// ///
/// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you /// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you
/// can with [`Handler::layer`]. /// can with [`Handler::layer`].
pub trait HandlerCallWithExtractors<T, B>: Sized { pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// The type of future calling this handler returns. /// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static; type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the extracted inputs. /// Call the handler with the extracted inputs.
fn call(self, extractors: T) -> <Self as HandlerCallWithExtractors<T, B>>::Future; fn call(self, state: S, extractors: T) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`]. /// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, B> { fn into_handler(self) -> IntoHandler<Self, T, S, B> {
IntoHandler { IntoHandler {
handler: self, handler: self,
_marker: PhantomData, _marker: PhantomData,
@ -67,10 +67,14 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// struct AdminPermissions {} /// struct AdminPermissions {}
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B: Send> FromRequest<B> for AdminPermissions { /// impl<S, B> FromRequest<S, B> for AdminPermissions
/// where
/// B: Send,
/// S: Send,
/// {
/// // check for admin permissions... /// // check for admin permissions...
/// # type Rejection = (); /// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> { /// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # todo!() /// # todo!()
/// # } /// # }
/// } /// }
@ -78,10 +82,14 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// struct User {} /// struct User {}
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B: Send> FromRequest<B> for User { /// impl<S, B> FromRequest<S, B> for User
/// where
/// B: Send,
/// S: Send,
/// {
/// // check for a logged in user... /// // check for a logged in user...
/// # type Rejection = (); /// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> { /// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # todo!() /// # todo!()
/// # } /// # }
/// } /// }
@ -96,9 +104,9 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// ); /// );
/// # let _: Router = app; /// # let _: Router = app;
/// ``` /// ```
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, B> fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, S, B>
where where
R: HandlerCallWithExtractors<Rt, B>, R: HandlerCallWithExtractors<Rt, S, B>,
{ {
Or { Or {
lhs: self, lhs: self,
@ -111,7 +119,7 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
macro_rules! impl_handler_call_with { macro_rules! impl_handler_call_with {
( $($ty:ident),* $(,)? ) => { ( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<F, Fut, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), B> for F impl<F, Fut, S, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), S, B> for F
where where
F: FnOnce($($ty,)*) -> Fut, F: FnOnce($($ty,)*) -> Fut,
Fut: Future + Send + 'static, Fut: Future + Send + 'static,
@ -122,8 +130,9 @@ macro_rules! impl_handler_call_with {
fn call( fn call(
self, self,
_state: S,
($($ty,)*): ($($ty,)*), ($($ty,)*): ($($ty,)*),
) -> <Self as HandlerCallWithExtractors<($($ty,)*), B>>::Future { ) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response) self($($ty,)*).map(IntoResponse::into_response)
} }
} }
@ -152,34 +161,35 @@ impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13,
/// ///
/// Created with [`HandlerCallWithExtractors::into_handler`]. /// Created with [`HandlerCallWithExtractors::into_handler`].
#[allow(missing_debug_implementations)] #[allow(missing_debug_implementations)]
pub struct IntoHandler<H, T, B> { pub struct IntoHandler<H, T, S, B> {
handler: H, handler: H,
_marker: PhantomData<fn() -> (T, B)>, _marker: PhantomData<fn() -> (T, S, B)>,
} }
impl<H, T, B> Handler<T, B> for IntoHandler<H, T, B> impl<H, T, S, B> Handler<T, S, B> for IntoHandler<H, T, S, B>
where where
H: HandlerCallWithExtractors<T, B> + Clone + Send + 'static, H: HandlerCallWithExtractors<T, S, B> + Clone + Send + 'static,
T: FromRequest<B> + Send + 'static, T: FromRequest<S, B> + Send + 'static,
T::Rejection: Send, T::Rejection: Send,
B: Send + 'static, B: Send + 'static,
S: Clone + Send + 'static,
{ {
type Future = BoxFuture<'static, Response>; type Future = BoxFuture<'static, Response>;
fn call(self, req: http::Request<B>) -> Self::Future { fn call(self, state: S, req: http::Request<B>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::new(req); let mut req = RequestParts::with_state(state.clone(), req);
match req.extract::<T>().await { match req.extract::<T>().await {
Ok(t) => self.handler.call(t).await, Ok(t) => self.handler.call(state, t).await,
Err(rejection) => rejection.into_response(), Err(rejection) => rejection.into_response(),
} }
}) })
} }
} }
impl<H, T, B> Copy for IntoHandler<H, T, B> where H: Copy {} impl<H, T, S, B> Copy for IntoHandler<H, T, S, B> where H: Copy {}
impl<H, T, B> Clone for IntoHandler<H, T, B> impl<H, T, S, B> Clone for IntoHandler<H, T, S, B>
where where
H: Clone, H: Clone,
{ {

View file

@ -15,16 +15,16 @@ use std::{future::Future, marker::PhantomData};
/// ///
/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or). /// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or).
#[allow(missing_debug_implementations)] #[allow(missing_debug_implementations)]
pub struct Or<L, R, Lt, Rt, B> { pub struct Or<L, R, Lt, Rt, S, B> {
pub(super) lhs: L, pub(super) lhs: L,
pub(super) rhs: R, pub(super) rhs: R,
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, B)>, pub(super) _marker: PhantomData<fn() -> (Lt, Rt, S, B)>,
} }
impl<B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, B> for Or<L, R, Lt, Rt, B> impl<S, B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, S, B> for Or<L, R, Lt, Rt, S, B>
where where
L: HandlerCallWithExtractors<Lt, B> + Send + 'static, L: HandlerCallWithExtractors<Lt, S, B> + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Send + 'static, R: HandlerCallWithExtractors<Rt, S, B> + Send + 'static,
Rt: Send + 'static, Rt: Send + 'static,
Lt: Send + 'static, Lt: Send + 'static,
B: Send + 'static, B: Send + 'static,
@ -37,46 +37,48 @@ where
fn call( fn call(
self, self,
state: S,
extractors: Either<Lt, Rt>, extractors: Either<Lt, Rt>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, B>>::Future { ) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors { match extractors {
Either::E1(lt) => self Either::E1(lt) => self
.lhs .lhs
.call(lt) .call(state, lt)
.map(IntoResponse::into_response as _) .map(IntoResponse::into_response as _)
.left_future(), .left_future(),
Either::E2(rt) => self Either::E2(rt) => self
.rhs .rhs
.call(rt) .call(state, rt)
.map(IntoResponse::into_response as _) .map(IntoResponse::into_response as _)
.right_future(), .right_future(),
} }
} }
} }
impl<B, L, R, Lt, Rt> Handler<(Lt, Rt), B> for Or<L, R, Lt, Rt, B> impl<S, B, L, R, Lt, Rt> Handler<(Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B>
where where
L: HandlerCallWithExtractors<Lt, B> + Clone + Send + 'static, L: HandlerCallWithExtractors<Lt, S, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Clone + Send + 'static, R: HandlerCallWithExtractors<Rt, S, B> + Clone + Send + 'static,
Lt: FromRequest<B> + Send + 'static, Lt: FromRequest<S, B> + Send + 'static,
Rt: FromRequest<B> + Send + 'static, Rt: FromRequest<S, B> + Send + 'static,
Lt::Rejection: Send, Lt::Rejection: Send,
Rt::Rejection: Send, Rt::Rejection: Send,
B: Send + 'static, B: Send + 'static,
S: Clone + Send + 'static,
{ {
// this puts `futures_util` in our public API but thats fine in axum-extra // this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>; type Future = BoxFuture<'static, Response>;
fn call(self, req: Request<B>) -> Self::Future { fn call(self, state: S, req: Request<B>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::new(req); let mut req = RequestParts::with_state(state.clone(), req);
if let Ok(lt) = req.extract::<Lt>().await { if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(lt).await; return self.lhs.call(state, lt).await;
} }
if let Ok(rt) = req.extract::<Rt>().await { if let Ok(rt) = req.extract::<Rt>().await {
return self.rhs.call(rt).await; return self.rhs.call(state, rt).await;
} }
StatusCode::NOT_FOUND.into_response() StatusCode::NOT_FOUND.into_response()
@ -84,14 +86,14 @@ where
} }
} }
impl<L, R, Lt, Rt, B> Copy for Or<L, R, Lt, Rt, B> impl<L, R, Lt, Rt, S, B> Copy for Or<L, R, Lt, Rt, S, B>
where where
L: Copy, L: Copy,
R: Copy, R: Copy,
{ {
} }
impl<L, R, Lt, Rt, B> Clone for Or<L, R, Lt, Rt, B> impl<L, R, Lt, Rt, S, B> Clone for Or<L, R, Lt, Rt, S, B>
where where
L: Clone, L: Clone,
R: Clone, R: Clone,

View file

@ -98,16 +98,17 @@ impl<S> JsonLines<S, AsResponse> {
} }
#[async_trait] #[async_trait]
impl<B, T> FromRequest<B> for JsonLines<T, AsExtractor> impl<S, B, T> FromRequest<S, B> for JsonLines<T, AsExtractor>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Into<Bytes>, B::Data: Into<Bytes>,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
T: DeserializeOwned, T: DeserializeOwned,
S: Send,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
// `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
// so we can call `AsyncRead::lines` and then convert it back to a `Stream` // so we can call `AsyncRead::lines` and then convert it back to a `Stream`

View file

@ -97,16 +97,17 @@ use std::ops::{Deref, DerefMut};
pub struct ProtoBuf<T>(pub T); pub struct ProtoBuf<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for ProtoBuf<T> impl<T, S, B> FromRequest<S, B> for ProtoBuf<T>
where where
T: Message + Default, T: Message + Default,
B: HttpBody + Send, B: HttpBody + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = ProtoBufRejection; type Rejection = ProtoBufRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let mut bytes = Bytes::from_request(req).await?; let mut bytes = Bytes::from_request(req).await?;
match T::decode(&mut bytes) { match T::decode(&mut bytes) {

View file

@ -1,12 +1,13 @@
//! Additional types for defining routes. //! Additional types for defining routes.
use axum::{ use axum::{
handler::Handler, handler::{Handler, HandlerWithoutStateExt},
http::Request, http::Request,
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
routing::{any, MethodRouter},
Router, Router,
}; };
use std::{convert::Infallible, future::ready}; use std::{convert::Infallible, future::ready, sync::Arc};
use tower_service::Service; use tower_service::Service;
mod resource; mod resource;
@ -29,7 +30,7 @@ pub use self::typed::{FirstElementIs, TypedPath};
pub use self::spa::SpaRouter; pub use self::spa::SpaRouter;
/// Extension trait that adds additional methods to [`Router`]. /// Extension trait that adds additional methods to [`Router`].
pub trait RouterExt<B>: sealed::Sealed { pub trait RouterExt<S, B>: sealed::Sealed {
/// Add a typed `GET` route to the router. /// Add a typed `GET` route to the router.
/// ///
/// The path will be inferred from the first argument to the handler function which must /// The path will be inferred from the first argument to the handler function which must
@ -39,7 +40,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self fn typed_get<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -52,7 +53,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self fn typed_delete<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -65,7 +66,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self fn typed_head<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -78,7 +79,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self fn typed_options<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -91,7 +92,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self fn typed_patch<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -104,7 +105,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self fn typed_post<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -117,7 +118,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self fn typed_put<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -130,7 +131,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self fn typed_trace<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath; P: TypedPath;
@ -159,7 +160,14 @@ pub trait RouterExt<B>: sealed::Sealed {
/// .route_with_tsr("/bar/", get(|| async {})); /// .route_with_tsr("/bar/", get(|| async {}));
/// # let _: Router = app; /// # let _: Router = app;
/// ``` /// ```
fn route_with_tsr<T>(self, path: &str, service: T) -> Self fn route_with_tsr(self, path: &str, method_router: MethodRouter<S, B>) -> Self
where
Self: Sized;
/// Add another route to the router with an additional "trailing slash redirect" route.
///
/// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
where where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse, T::Response: IntoResponse,
@ -167,14 +175,15 @@ pub trait RouterExt<B>: sealed::Sealed {
Self: Sized; Self: Sized;
} }
impl<B> RouterExt<B> for Router<B> impl<S, B> RouterExt<S, B> for Router<S, B>
where where
B: axum::body::HttpBody + Send + 'static, B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{ {
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self fn typed_get<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -184,7 +193,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self fn typed_delete<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -194,7 +203,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self fn typed_head<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -204,7 +213,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self fn typed_options<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -214,7 +223,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self fn typed_patch<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -224,7 +233,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self fn typed_post<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -234,7 +243,7 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self fn typed_put<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
@ -244,41 +253,56 @@ where
#[cfg(feature = "typed-routing")] #[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self fn typed_trace<H, T, P>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static, T: FirstElementIs<P> + 'static,
P: TypedPath, P: TypedPath,
{ {
self.route(P::PATH, axum::routing::trace(handler)) self.route(P::PATH, axum::routing::trace(handler))
} }
fn route_with_tsr<T>(mut self, path: &str, service: T) -> Self fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self
where
Self: Sized,
{
self = self.route(path, method_router);
let redirect_service = {
let path: Arc<str> = path.into();
(move || ready(Redirect::permanent(&path))).into_service()
};
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route_service(path_without_trailing_slash, redirect_service)
} else {
self.route_service(&format!("{}/", path), redirect_service)
}
}
fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
where where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse, T::Response: IntoResponse,
T::Future: Send + 'static, T::Future: Send + 'static,
Self: Sized, Self: Sized,
{ {
self = self.route(path, service); self = self.route_service(path, service);
let redirect = Redirect::permanent(path); let redirect = Redirect::permanent(path);
if let Some(path_without_trailing_slash) = path.strip_suffix('/') { if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route( self.route(
path_without_trailing_slash, path_without_trailing_slash,
(move || ready(redirect.clone())).into_service(), any(move || ready(redirect.clone())),
) )
} else { } else {
self.route( self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
&format!("{}/", path),
(move || ready(redirect.clone())).into_service(),
)
} }
} }
} }
mod sealed { mod sealed {
pub trait Sealed {} pub trait Sealed {}
impl<B> Sealed for axum::Router<B> {} impl<S, B> Sealed for axum::Router<S, B> {}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -1,13 +1,9 @@
use axum::{ use axum::{
body::Body, body::Body,
handler::Handler, handler::Handler,
http::Request, routing::{delete, get, on, post, MethodFilter, MethodRouter},
response::IntoResponse,
routing::{delete, get, on, post, MethodFilter},
Router, Router,
}; };
use std::{convert::Infallible, fmt};
use tower_service::Service;
/// A resource which defines a set of conventional CRUD routes. /// A resource which defines a set of conventional CRUD routes.
/// ///
@ -34,14 +30,15 @@ use tower_service::Service;
/// .destroy(|Path(user_id): Path<u64>| async {}); /// .destroy(|Path(user_id): Path<u64>| async {});
/// ///
/// let app = Router::new().merge(users); /// let app = Router::new().merge(users);
/// # let _: Router<axum::body::Body> = app; /// # let _: Router = app;
/// ``` /// ```
pub struct Resource<B = Body> { #[derive(Debug)]
pub struct Resource<S = (), B = Body> {
pub(crate) name: String, pub(crate) name: String,
pub(crate) router: Router<B>, pub(crate) router: Router<S, B>,
} }
impl<B> Resource<B> impl<B> Resource<(), B>
where where
B: axum::body::HttpBody + Send + 'static, B: axum::body::HttpBody + Send + 'static,
{ {
@ -49,16 +46,29 @@ where
/// ///
/// All routes will be nested at `/{resource_name}`. /// All routes will be nested at `/{resource_name}`.
pub fn named(resource_name: &str) -> Self { pub fn named(resource_name: &str) -> Self {
Self::named_with((), resource_name)
}
}
impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named_with(state: S, resource_name: &str) -> Self {
Self { Self {
name: resource_name.to_owned(), name: resource_name.to_owned(),
router: Default::default(), router: Router::with_state(state),
} }
} }
/// Add a handler at `GET /{resource_name}`. /// Add a handler at `GET /{resource_name}`.
pub fn index<H, T>(self, handler: H) -> Self pub fn index<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = self.index_create_path(); let path = self.index_create_path();
@ -68,7 +78,7 @@ where
/// Add a handler at `POST /{resource_name}`. /// Add a handler at `POST /{resource_name}`.
pub fn create<H, T>(self, handler: H) -> Self pub fn create<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = self.index_create_path(); let path = self.index_create_path();
@ -78,7 +88,7 @@ where
/// Add a handler at `GET /{resource_name}/new`. /// Add a handler at `GET /{resource_name}/new`.
pub fn new<H, T>(self, handler: H) -> Self pub fn new<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = format!("/{}/new", self.name); let path = format!("/{}/new", self.name);
@ -88,7 +98,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id`. /// Add a handler at `GET /{resource_name}/:{resource_name}_id`.
pub fn show<H, T>(self, handler: H) -> Self pub fn show<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = self.show_update_destroy_path(); let path = self.show_update_destroy_path();
@ -98,7 +108,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`.
pub fn edit<H, T>(self, handler: H) -> Self pub fn edit<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = format!("/{0}/:{0}_id/edit", self.name); let path = format!("/{0}/:{0}_id/edit", self.name);
@ -108,7 +118,7 @@ where
/// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`.
pub fn update<H, T>(self, handler: H) -> Self pub fn update<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = self.show_update_destroy_path(); let path = self.show_update_destroy_path();
@ -118,7 +128,7 @@ where
/// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`.
pub fn destroy<H, T>(self, handler: H) -> Self pub fn destroy<H, T>(self, handler: H) -> Self
where where
H: Handler<T, B>, H: Handler<T, S, B>,
T: 'static, T: 'static,
{ {
let path = self.show_update_destroy_path(); let path = self.show_update_destroy_path();
@ -133,13 +143,8 @@ where
format!("/{0}/:{0}_id", self.name) format!("/{0}/:{0}_id", self.name)
} }
fn route<T>(mut self, path: &str, svc: T) -> Self fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
where self.router = self.router.route(path, method_router);
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
self.router = self.router.route(path, svc);
self self
} }
} }
@ -150,21 +155,13 @@ impl<B> From<Resource<B>> for Router<B> {
} }
} }
impl<B> fmt::Debug for Resource<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Resource")
.field("name", &self.name)
.field("router", &self.router)
.finish()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[allow(unused_imports)] #[allow(unused_imports)]
use super::*; use super::*;
use axum::{extract::Path, http::Method, Router}; use axum::{extract::Path, http::Method, Router};
use tower::ServiceExt; use http::Request;
use tower::{Service, ServiceExt};
#[tokio::test] #[tokio::test]
async fn works() { async fn works() {
@ -220,7 +217,7 @@ mod tests {
); );
} }
async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { async fn call_route(app: &mut Router<()>, method: Method, uri: &str) -> String {
let res = app let res = app
.ready() .ready()
.await .await

View file

@ -36,7 +36,7 @@ use tower_service::Service;
/// .merge(spa) /// .merge(spa)
/// // we can still add other routes /// // we can still add other routes
/// .route("/api/foo", get(api_foo)); /// .route("/api/foo", get(api_foo));
/// # let _: Router<axum::body::Body> = app; /// # let _: Router = app;
/// ///
/// async fn api_foo() {} /// async fn api_foo() {}
/// ``` /// ```
@ -101,7 +101,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// .index_file("another_file.html"); /// .index_file("another_file.html");
/// ///
/// let app = Router::new().merge(spa); /// let app = Router::new().merge(spa);
/// # let _: Router<axum::body::Body> = app; /// # let _: Router = app;
/// ``` /// ```
pub fn index_file<P>(mut self, path: P) -> Self pub fn index_file<P>(mut self, path: P) -> Self
where where
@ -136,7 +136,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// } /// }
/// ///
/// let app = Router::new().merge(spa); /// let app = Router::new().merge(spa);
/// # let _: Router<axum::body::Body> = app; /// # let _: Router = app;
/// ``` /// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> { pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> {
SpaRouter { SpaRouter {
@ -147,7 +147,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
} }
} }
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<B> impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
where where
F: Clone + Send + 'static, F: Clone + Send + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>, HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
@ -162,7 +162,7 @@ where
Router::new() Router::new()
.nest(&spa.paths.assets_path, assets_service) .nest(&spa.paths.assets_path, assets_service)
.fallback( .fallback_service(
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error), get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
) )
} }
@ -264,6 +264,13 @@ mod tests {
let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error); let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);
Router::<Body>::new().merge(spa); Router::<_, Body>::new().merge(spa);
}
#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router<String> = Router::with_state(String::new())
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}));
} }
} }

View file

@ -60,7 +60,7 @@ use http::Uri;
/// async fn users_destroy(_: UsersCollection) { /* ... */ } /// async fn users_destroy(_: UsersCollection) { /* ... */ }
/// ///
/// # /// #
/// # let app: Router<axum::body::Body> = app; /// # let app: Router = app;
/// ``` /// ```
/// ///
/// # Using `#[derive(TypedPath)]` /// # Using `#[derive(TypedPath)]`

View file

@ -203,7 +203,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr
#[allow(warnings)] #[allow(warnings)]
fn #name() fn #name()
where where
#ty: ::axum::extract::FromRequest<#body_ty> + Send, #ty: ::axum::extract::FromRequest<(), #body_ty> + Send,
{} {}
} }
}) })

View file

@ -218,16 +218,17 @@ fn impl_struct_by_extracting_each_field(
Ok(quote! { Ok(quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send, B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>, B::Error: ::std::convert::Into<::axum::BoxError>,
S: Send,
{ {
type Rejection = #rejection_ident; type Rejection = #rejection_ident;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<B>, req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self { ::std::result::Result::Ok(Self {
#(#extract_fields)* #(#extract_fields)*
@ -422,7 +423,7 @@ fn extract_each_field_rejection(
Ok(quote_spanned! {ty_span=> Ok(quote_spanned! {ty_span=>
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body>>::Rejection), #variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection),
}) })
}) })
.collect::<syn::Result<Vec<_>>>()?; .collect::<syn::Result<Vec<_>>>()?;
@ -609,26 +610,26 @@ fn impl_struct_by_extracting_all_at_once(
quote! { #rejection } quote! { #rejection }
} else { } else {
quote! { quote! {
<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection <#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
} }
}; };
let rejection_bound = rejection.as_ref().map(|rejection| { let rejection_bound = rejection.as_ref().map(|rejection| {
if generic_ident.is_some() { if generic_ident.is_some() {
quote! { quote! {
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<B>>::Rejection>, #rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>,
} }
} else { } else {
quote! { quote! {
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection>, #rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>,
} }
} }
}).unwrap_or_default(); }).unwrap_or_default();
let impl_generics = if generic_ident.is_some() { let impl_generics = if generic_ident.is_some() {
quote! { B, T } quote! { S, B, T }
} else { } else {
quote! { B } quote! { S, B }
}; };
let type_generics = generic_ident let type_generics = generic_ident
@ -653,18 +654,19 @@ fn impl_struct_by_extracting_all_at_once(
Ok(quote_spanned! {path_span=> Ok(quote_spanned! {path_span=>
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<#impl_generics> ::axum::extract::FromRequest<B> for #ident #type_generics impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics
where where
#path<#via_type_generics>: ::axum::extract::FromRequest<B>, #path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound #rejection_bound
B: ::std::marker::Send, B: ::std::marker::Send,
S: ::std::marker::Send,
{ {
type Rejection = #associated_rejection_type; type Rejection = #associated_rejection_type;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<B>, req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<B>::from_request(req) ::axum::extract::FromRequest::<S, B>::from_request(req)
.await .await
.map(|#path(value)| #value_to_self) .map(|#path(value)| #value_to_self)
.map_err(::std::convert::From::from) .map_err(::std::convert::From::from)
@ -709,7 +711,7 @@ fn impl_enum_by_extracting_all_at_once(
quote! { #rejection } quote! { #rejection }
} else { } else {
quote! { quote! {
<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection <#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
} }
}; };
@ -718,18 +720,19 @@ fn impl_enum_by_extracting_all_at_once(
Ok(quote_spanned! {path_span=> Ok(quote_spanned! {path_span=>
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send, B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>, B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send,
{ {
type Rejection = #associated_rejection_type; type Rejection = #associated_rejection_type;
async fn from_request( async fn from_request(
req: &mut ::axum::extract::RequestParts<B>, req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<B>::from_request(req) ::axum::extract::FromRequest::<S, B>::from_request(req)
.await .await
.map(|#path(inner)| inner) .map(|#path(inner)| inner)
.map_err(::std::convert::From::from) .map_err(::std::convert::From::from)

View file

@ -125,7 +125,7 @@ mod typed_path;
/// ``` /// ```
/// pub struct ViaExtractor<T>(pub T); /// pub struct ViaExtractor<T>(pub T);
/// ///
/// // impl<T, B> FromRequest<B> for ViaExtractor<T> { ... } /// // impl<T, S, B> FromRequest<S, B> for ViaExtractor<T> { ... }
/// ``` /// ```
/// ///
/// More complex via extractors are not supported and require writing a manual implementation. /// More complex via extractors are not supported and require writing a manual implementation.
@ -223,14 +223,15 @@ mod typed_path;
/// struct OtherExtractor; /// struct OtherExtractor;
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B> FromRequest<B> for OtherExtractor /// impl<S, B> FromRequest<S, B> for OtherExtractor
/// where /// where
/// B: Send + 'static, /// B: Send,
/// S: Send,
/// { /// {
/// // this rejection doesn't implement `Display` and `Error` /// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String); /// type Rejection = (StatusCode, String);
/// ///
/// async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { /// async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ... /// // ...
/// # unimplemented!() /// # unimplemented!()
/// } /// }

View file

@ -127,13 +127,14 @@ fn expand_named_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req) ::axum::extract::Path::from_request(req)
.await .await
.map(|path| path.0) .map(|path| path.0)
@ -229,13 +230,14 @@ fn expand_unnamed_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req) ::axum::extract::Path::from_request(req)
.await .await
.map(|path| path.0) .map(|path| path.0)
@ -310,13 +312,14 @@ fn expand_unit_fields(
let from_request_impl = quote! { let from_request_impl = quote! {
#[::axum::async_trait] #[::axum::async_trait]
#[automatically_derived] #[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = #rejection_assoc_type; type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> { async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH { if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self) Ok(Self)
} else { } else {
@ -387,7 +390,7 @@ enum Segment {
fn path_rejection() -> TokenStream { fn path_rejection() -> TokenStream {
quote! { quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection <::axum::extract::Path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
} }
} }

View file

@ -1,17 +1,17 @@
error[E0277]: the trait bound `bool: FromRequest<Body>` is not satisfied error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23 --> tests/debug_handler/fail/argument_not_extractor.rs:4:23
| |
4 | async fn handler(foo: bool) {} 4 | async fn handler(foo: bool) {}
| ^^^^ the trait `FromRequest<Body>` is not implemented for `bool` | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool`
| |
= help: the following other types implement trait `FromRequest<B>`: = help: the following other types implement trait `FromRequest<S, B>`:
() <() as FromRequest<S, B>>
(T1, T2) <(T1, T2) as FromRequest<S, B>>
(T1, T2, T3) <(T1, T2, T3) as FromRequest<S, B>>
(T1, T2, T3, T4) <(T1, T2, T3, T4) as FromRequest<S, B>>
(T1, T2, T3, T4, T5) <(T1, T2, T3, T4, T5) as FromRequest<S, B>>
(T1, T2, T3, T4, T5, T6) <(T1, T2, T3, T4, T5, T6) as FromRequest<S, B>>
(T1, T2, T3, T4, T5, T6, T7) <(T1, T2, T3, T4, T5, T6, T7) as FromRequest<S, B>>
(T1, T2, T3, T4, T5, T6, T7, T8) <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest<S, B>>
and 33 others and 34 others
= help: see issue #48214 = help: see issue #48214

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A; struct A;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send + 'static, B: Send,
S: Send,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_mut.rs:23:22 --> tests/debug_handler/fail/extract_self_mut.rs:24:22
| |
23 | async fn handler(&mut self) {} 24 | async fn handler(&mut self) {}
| ^^^^^^^^^ | ^^^^^^^^^

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A; struct A;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send + 'static, B: Send,
S: Send,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_ref.rs:23:22 --> tests/debug_handler/fail/extract_self_ref.rs:24:22
| |
23 | async fn handler(&self) {} 24 | async fn handler(&self) {}
| ^^^^^ | ^^^^^

View file

@ -120,13 +120,14 @@ impl A {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send + 'static, B: Send,
S: Send,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A; struct A;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for A impl<S, B> FromRequest<S, B> for A
where where
B: Send + 'static, B: Send,
S: Send,
{ {
type Rejection = (); type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest; use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)] #[derive(FromRequest)]
struct Extractor(#[from_request(via(Extension), via(Extension))] State); struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State);
#[derive(Clone)] #[derive(Clone)]
struct State; struct State;

View file

@ -1,13 +1,5 @@
error: `via` specified more than once error: `via` specified more than once
--> tests/from_request/fail/double_via_attr.rs:5:49 --> tests/from_request/fail/double_via_attr.rs:4:55
| |
5 | struct Extractor(#[from_request(via(Extension), via(Extension))] State); 4 | struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State);
| ^^^ | ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/double_via_attr.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router}; use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest; use axum_macros::FromRequest;
#[derive(FromRequest, Clone)] #[derive(FromRequest, Clone)]
@ -7,5 +7,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {} async fn foo(_: Extractor<()>) {}
fn main() { fn main() {
Router::<Body>::new().route("/", get(foo)); Router::<(), Body>::new().route("/", get(foo));
} }

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
5 | struct Extractor<T>(T); 5 | struct Extractor<T>(T);
| ^ | ^
warning: unused import: `Extension` error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via.rs:1:38 --> tests/from_request/fail/generic_without_via.rs:10:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via.rs:10:42
| |
10 | Router::<Body>::new().route("/", get(foo)); 10 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}` | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| | | |
| required by a bound introduced by this call | required by a bound introduced by this call
| |
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>` = help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get` note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs --> $WORKSPACE/axum/src/routing/method_routing.rs
| |

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router}; use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest; use axum_macros::FromRequest;
#[derive(FromRequest, Clone)] #[derive(FromRequest, Clone)]
@ -8,5 +8,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {} async fn foo(_: Extractor<()>) {}
fn main() { fn main() {
Router::<Body>::new().route("/", get(foo)); Router::<(), Body>::new().route("/", get(foo));
} }

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
6 | struct Extractor<T>(T); 6 | struct Extractor<T>(T);
| ^ | ^
warning: unused import: `Extension` error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection.rs:1:38 --> tests/from_request/fail/generic_without_via_rejection.rs:11:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection.rs:11:42
| |
11 | Router::<Body>::new().route("/", get(foo)); 11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}` | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| | | |
| required by a bound introduced by this call | required by a bound introduced by this call
| |
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>` = help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get` note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs --> $WORKSPACE/axum/src/routing/method_routing.rs
| |

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router}; use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest; use axum_macros::FromRequest;
#[derive(FromRequest, Clone)] #[derive(FromRequest, Clone)]
@ -8,5 +8,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {} async fn foo(_: Extractor<()>) {}
fn main() { fn main() {
Router::<Body>::new().route("/", get(foo)); Router::<(), Body>::new().route("/", get(foo));
} }

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
6 | struct Extractor<T>(T); 6 | struct Extractor<T>(T);
| ^ | ^
warning: unused import: `Extension` error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:1:38 --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:42
| |
11 | Router::<Body>::new().route("/", get(foo)); 11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}` | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| | | |
| required by a bound introduced by this call | required by a bound introduced by this call
| |
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>` = help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get` note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs --> $WORKSPACE/axum/src/routing/method_routing.rs
| |

View file

@ -4,15 +4,15 @@ error: cannot use `rejection` without `via`
18 | #[from_request(rejection(MyRejection))] 18 | #[from_request(rejection(MyRejection))]
| ^^^^^^^^^^^ | ^^^^^^^^^^^
error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {handler}: Handler<_, _>` is not satisfied error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {handler}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50 --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50
| |
10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); 10 | let _: Router = Router::new().route("/", get(handler).post(handler_result));
| --- ^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(MyExtractor) -> impl Future<Output = ()> {handler}` | --- ^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(MyExtractor) -> impl Future<Output = ()> {handler}`
| | | |
| required by a bound introduced by this call | required by a bound introduced by this call
| |
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>` = help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get` note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs --> $WORKSPACE/axum/src/routing/method_routing.rs
| |
@ -20,18 +20,18 @@ note: required by a bound in `axum::routing::get`
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get`
= note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}: Handler<_, _>` is not satisfied error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64 --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64
| |
10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); 10 | let _: Router = Router::new().route("/", get(handler).post(handler_result));
| ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}` | ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}`
| | | |
| required by a bound introduced by this call | required by a bound introduced by this call
| |
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>` = help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `MethodRouter::<B>::post` note: required by a bound in `MethodRouter::<S, B>::post`
--> $WORKSPACE/axum/src/routing/method_routing.rs --> $WORKSPACE/axum/src/routing/method_routing.rs
| |
| chained_handler_fn!(post, POST); | chained_handler_fn!(post, POST);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::<B>::post` | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::<S, B>::post`
= note: this error originates in the macro `chained_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) = note: this error originates in the macro `chained_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest; use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)] #[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error), via(Extension))] #[from_request(rejection_derive(!Error), via(axum::Extension))]
struct Extractor { struct Extractor {
config: String, config: String,
} }

View file

@ -1,13 +1,5 @@
error: cannot use both `rejection_derive` and `via` error: cannot use both `rejection_derive` and `via`
--> tests/from_request/fail/rejection_derive_and_via.rs:5:42 --> tests/from_request/fail/rejection_derive_and_via.rs:4:42
| |
5 | #[from_request(rejection_derive(!Error), via(Extension))] 4 | #[from_request(rejection_derive(!Error), via(axum::Extension))]
| ^^^ | ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/rejection_derive_and_via.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest; use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)] #[derive(FromRequest, Clone)]
#[from_request(via(Extension), rejection_derive(!Error))] #[from_request(via(axum::Extension), rejection_derive(!Error))]
struct Extractor { struct Extractor {
config: String, config: String,
} }

View file

@ -1,13 +1,5 @@
error: cannot use both `via` and `rejection_derive` error: cannot use both `via` and `rejection_derive`
--> tests/from_request/fail/via_and_rejection_derive.rs:5:32 --> tests/from_request/fail/via_and_rejection_derive.rs:4:38
| |
5 | #[from_request(via(Extension), rejection_derive(!Error))] 4 | #[from_request(via(axum::Extension), rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/via_and_rejection_derive.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,9 +1,8 @@
use axum_macros::FromRequest; use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)] #[derive(FromRequest)]
#[from_request(via(Extension))] #[from_request(via(axum::Extension))]
struct Extractor(#[from_request(via(Extension))] State); struct Extractor(#[from_request(via(axum::Extension))] State);
#[derive(Clone)] #[derive(Clone)]
struct State; struct State;

View file

@ -1,13 +1,5 @@
error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container
--> tests/from_request/fail/via_on_container_and_field.rs:6:33 --> tests/from_request/fail/via_on_container_and_field.rs:5:33
| |
6 | struct Extractor(#[from_request(via(Extension))] State); 5 | struct Extractor(#[from_request(via(axum::Extension))] State);
| ^^^ | ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/via_on_container_and_field.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -15,7 +15,7 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<Body, Rejection = JsonRejection>, Extractor: FromRequest<(), Body, Rejection = JsonRejection>,
{ {
} }

View file

@ -14,13 +14,14 @@ struct Extractor {
struct OtherExtractor; struct OtherExtractor;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for OtherExtractor impl<S, B> FromRequest<S, B> for OtherExtractor
where where
B: Send + 'static, B: Send,
S: Send,
{ {
type Rejection = OtherExtractorRejection; type Rejection = OtherExtractorRejection;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!() unimplemented!()
} }
} }

View file

@ -5,7 +5,7 @@ struct Extractor {}
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>, Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{ {
} }

View file

@ -5,7 +5,7 @@ struct Extractor();
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>, Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{ {
} }

View file

@ -8,5 +8,5 @@ enum Extractor {}
async fn foo(_: Extractor) {} async fn foo(_: Extractor) {}
fn main() { fn main() {
Router::<Body>::new().route("/", get(foo)); Router::<(), Body>::new().route("/", get(foo));
} }

View file

@ -18,7 +18,7 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<Body, Rejection = ExtractorRejection>, Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
{ {
} }

View file

@ -25,7 +25,7 @@ struct Extractor {
fn assert_from_request() fn assert_from_request()
where where
Extractor: FromRequest<Body, Rejection = ExtractorRejection>, Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
{ {
} }

View file

@ -28,14 +28,15 @@ struct MyExtractor {
struct OtherExtractor; struct OtherExtractor;
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for OtherExtractor impl<S, B> FromRequest<S, B> for OtherExtractor
where where
B: Send + 'static, B: Send + 'static,
S: Send,
{ {
// this rejection doesn't implement `Display` and `Error` // this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String); type Rejection = (StatusCode, String);
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
todo!() todo!()
} }
} }

View file

@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String);
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body>, Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{ {
} }

View file

@ -13,7 +13,7 @@ struct Payload {}
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body>, Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{ {
} }

View file

@ -27,7 +27,7 @@ struct Payload {}
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body>, Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{ {
} }

View file

@ -1,5 +1,5 @@
use axum::Extension;
use axum_macros::FromRequest; use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)] #[derive(FromRequest)]
struct Extractor(#[from_request(via(Extension))] State); struct Extractor(#[from_request(via(Extension))] State);
@ -9,7 +9,7 @@ struct State;
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body>, Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{ {
} }

View file

@ -5,7 +5,7 @@ struct Extractor;
fn assert_from_request() fn assert_from_request()
where where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>, Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{ {
} }

View file

@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is
(T0, T1, T2, T3) (T0, T1, T2, T3)
and 138 others and 138 others
= note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath`
= note: required because of the requirements on the impl of `FromRequest<B>` for `axum::extract::Path<MyPath>` = note: required because of the requirements on the impl of `FromRequest<S, B>` for `axum::extract::Path<MyPath>`
= note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -40,7 +40,7 @@ impl Default for MyRejection {
} }
fn main() { fn main() {
axum::Router::<axum::body::Body>::new() axum::Router::<(), axum::body::Body>::new()
.typed_get(|_: Result<MyPathNamed, MyRejection>| async {}) .typed_get(|_: Result<MyPathNamed, MyRejection>| async {})
.typed_post(|_: Result<MyPathUnnamed, MyRejection>| async {}) .typed_post(|_: Result<MyPathUnnamed, MyRejection>| async {})
.typed_put(|_: Result<MyPathUnit, MyRejection>| async {}); .typed_put(|_: Result<MyPathUnit, MyRejection>| async {});

View file

@ -9,7 +9,7 @@ struct MyPath {
} }
fn main() { fn main() {
axum::Router::<axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id");
assert_eq!( assert_eq!(

View file

@ -20,7 +20,7 @@ struct UsersIndex;
async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {} async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
fn main() { fn main() {
axum::Router::<axum::body::Body>::new() axum::Router::<(), axum::body::Body>::new()
.typed_get(option_handler) .typed_get(option_handler)
.typed_post(result_handler) .typed_post(result_handler)
.typed_post(result_handler_unit_struct); .typed_post(result_handler_unit_struct);

View file

@ -8,7 +8,7 @@ pub type Result<T> = std::result::Result<T, ()>;
struct MyPath(u32, u32); struct MyPath(u32, u32);
fn main() { fn main() {
axum::Router::<axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id");
assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2");

View file

@ -5,7 +5,7 @@ use axum_extra::routing::TypedPath;
struct MyPath; struct MyPath;
fn main() { fn main() {
axum::Router::<axum::body::Body>::new() axum::Router::<(), axum::body::Body>::new()
.route("/", axum::routing::get(|_: MyPath| async {})); .route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users"); assert_eq!(MyPath::PATH, "/users");

View file

@ -8,5 +8,5 @@ struct MyPath {
} }
fn main() { fn main() {
axum::Router::<axum::body::Body>::new().typed_get(|_: MyPath| async {}); axum::Router::<(), axum::body::Body>::new().typed_get(|_: MyPath| async {});
} }

View file

@ -35,6 +35,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Added `debug_handler` which is an attribute macro that improves - **added:** Added `debug_handler` which is an attribute macro that improves
type errors when applied to handler function. It is re-exported from type errors when applied to handler function. It is re-exported from
`axum-macros` `axum-macros`
- **added:** Added new type safe `State` extractor. This can be used with
`Router::with_state` and gives compile errors for missing states, whereas
`Extension` would result in runtime errors ([#1155])
- **breaking:** The following types or traits have a new `S` type param
(`()` by default) which represents the state ([#1155]):
- `FromRequest`
- `RequestParts`
- `Router`
- `MethodRouter`
- `Handler`
- **breaking:** `Router::route` now only accepts `MethodRouter`s created with
`get`, `post`, etc ([#1155])
- **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155])
- **added:** Support any middleware response that implements `IntoResponse` ([#1152]) - **added:** Support any middleware response that implements `IntoResponse` ([#1152])
- **breaking:** Require middleware added with `Handler::layer` to have - **breaking:** Require middleware added with `Handler::layer` to have
`Infallible` as the error type ([#1152]) `Infallible` as the error type ([#1152])
@ -54,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1130]: https://github.com/tokio-rs/axum/pull/1130 [#1130]: https://github.com/tokio-rs/axum/pull/1130
[#1135]: https://github.com/tokio-rs/axum/pull/1135 [#1135]: https://github.com/tokio-rs/axum/pull/1135
[#1152]: https://github.com/tokio-rs/axum/pull/1152 [#1152]: https://github.com/tokio-rs/axum/pull/1152
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1171]: https://github.com/tokio-rs/axum/pull/1171 [#1171]: https://github.com/tokio-rs/axum/pull/1171
[#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1248]: https://github.com/tokio-rs/axum/pull/1248 [#1248]: https://github.com/tokio-rs/axum/pull/1248

View file

@ -69,7 +69,7 @@ let some_fallible_service = tower::service_fn(|_req| async {
Ok::<_, anyhow::Error>(Response::new(Body::empty())) Ok::<_, anyhow::Error>(Response::new(Body::empty()))
}); });
let app = Router::new().route( let app = Router::new().route_service(
"/", "/",
// we cannot route to `some_fallible_service` directly since it might fail. // we cannot route to `some_fallible_service` directly since it might fail.
// we have to use `handle_error` which converts its errors into responses // we have to use `handle_error` which converts its errors into responses

View file

@ -421,13 +421,14 @@ use http::{StatusCode, header::{HeaderValue, USER_AGENT}};
struct ExtractUserAgent(HeaderValue); struct ExtractUserAgent(HeaderValue);
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for ExtractUserAgent impl<S, B> FromRequest<S, B> for ExtractUserAgent
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = (StatusCode, &'static str); type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = req.headers().get(USER_AGENT) { if let Some(user_agent) = req.headers().get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone())) Ok(ExtractUserAgent(user_agent.clone()))
} else { } else {
@ -472,13 +473,14 @@ struct AuthenticatedUser {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for AuthenticatedUser impl<S, B> FromRequest<S, B> for AuthenticatedUser
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Response; type Rejection = Response;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(token)) = let TypedHeader(Authorization(token)) =
TypedHeader::<Authorization<Bearer>>::from_request(req) TypedHeader::<Authorization<Bearer>>::from_request(req)
.await .await
@ -633,7 +635,7 @@ fn token_is_valid(token: &str) -> bool {
} }
let app = Router::new().layer(middleware::from_fn(auth_middleware)); let app = Router::new().layer(middleware::from_fn(auth_middleware));
# let _: Router = app; # let _: Router<()> = app;
``` ```
[`body::Body`]: crate::body::Body [`body::Body`]: crate::body::Body

View file

@ -11,7 +11,7 @@ use axum::{
http::{StatusCode, Method, Uri}, http::{StatusCode, Method, Uri},
}; };
let handler = get(|| async {}).fallback(fallback.into_service()); let handler = get(|| async {}).fallback(fallback);
let app = Router::new().route("/", handler); let app = Router::new().route("/", handler);
@ -36,11 +36,9 @@ use axum::{
http::{StatusCode, Uri}, http::{StatusCode, Uri},
}; };
let one = get(|| async {}) let one = get(|| async {}).fallback(fallback_one);
.fallback(fallback_one.into_service());
let two = post(|| async {}) let two = post(|| async {}).fallback(fallback_two);
.fallback(fallback_two.into_service());
let method_route = one.merge(two); let method_route = one.merge(two);

View file

@ -6,7 +6,8 @@
- [Ordering](#ordering) - [Ordering](#ordering)
- [Writing middleware](#writing-middleware) - [Writing middleware](#writing-middleware)
- [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) - [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure)
- [Sharing state between handlers and middleware](#sharing-state-between-handlers-and-middleware) - [Accessing state in middleware](#accessing-state-in-middleware)
- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers)
# Intro # Intro
@ -95,7 +96,7 @@ let app = Router::new()
.layer(layer_one) .layer(layer_one)
.layer(layer_two) .layer(layer_two)
.layer(layer_three); .layer(layer_three);
# let app: Router<axum::body::Body> = app; # let _: Router<(), axum::body::Body> = app;
``` ```
Think of the middleware as being layered like an onion where each new layer Think of the middleware as being layered like an onion where each new layer
@ -154,7 +155,7 @@ let app = Router::new()
.layer(layer_two) .layer(layer_two)
.layer(layer_three), .layer(layer_three),
); );
# let app: Router<axum::body::Body> = app; # let _: Router<(), axum::body::Body> = app;
``` ```
`ServiceBuilder` works by composing all layers into one such that they run top `ServiceBuilder` works by composing all layers into one such that they run top
@ -386,9 +387,119 @@ Also note that handlers created from async functions don't care about
backpressure and are always ready. So if you're not using any Tower backpressure and are always ready. So if you're not using any Tower
middleware you don't have to worry about any of this. middleware you don't have to worry about any of this.
# Sharing state between handlers and middleware # Accessing state in middleware
State can be shared between middleware and handlers using [request extensions]: Handlers can access state using the [`State`] extractor but this isn't available
to middleware. Instead you have to pass the state directly to middleware using
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
fields (if you're implementing a [`tower::Layer`])
## Accessing state in `axum::middleware::from_fn`
```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};
#[derive(Clone)]
struct AppState {}
async fn my_middleware<B>(
state: AppState,
req: Request<B>,
next: Next<B>,
) -> Response {
next.run(req).await
}
async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(middleware::from_fn(move |req, next| {
my_middleware(state.clone(), req, next)
}));
# let _: Router<_> = app;
```
## Accessing state in custom `tower::Layer`s
```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};
use tower::{Layer, Service};
use std::task::{Context, Poll};
#[derive(Clone)]
struct AppState {}
#[derive(Clone)]
struct MyLayer {
state: AppState,
}
impl<S> Layer<S> for MyLayer {
type Service = MyService<S>;
fn layer(&self, inner: S) -> Self::Service {
MyService {
inner,
state: self.state.clone(),
}
}
}
#[derive(Clone)]
struct MyService<S> {
inner: S,
state: AppState,
}
impl<S, B> Service<Request<B>> for MyService<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
// do something with `self.state`
self.inner.call(req)
}
}
async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(MyLayer { state });
# let _: Router<_> = app;
```
# Passing state from middleware to handlers
State can be passed from middleware to handlers using [request extensions]:
```rust ```rust
use axum::{ use axum::{
@ -415,6 +526,8 @@ async fn auth<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusC
}; };
if let Some(current_user) = authorize_current_user(auth_header).await { if let Some(current_user) = authorize_current_user(auth_header).await {
// insert the current user into a request extension so the handler can
// extract it
req.extensions_mut().insert(current_user); req.extensions_mut().insert(current_user);
Ok(next.run(req).await) Ok(next.run(req).await)
} else { } else {
@ -437,7 +550,7 @@ async fn handler(
let app = Router::new() let app = Router::new()
.route("/", get(handler)) .route("/", get(handler))
.route_layer(middleware::from_fn(auth)); .route_layer(middleware::from_fn(auth));
# let app: Router = app; # let _: Router<()> = app;
``` ```
[Response extensions] can also be used but note that request extensions are not [Response extensions] can also be used but note that request extensions are not
@ -462,3 +575,4 @@ extensions you need.
[`MethodRouter::route_layer`]: crate::routing::MethodRouter::route_layer [`MethodRouter::route_layer`]: crate::routing::MethodRouter::route_layer
[request extensions]: https://docs.rs/http/latest/http/request/struct.Request.html#method.extensions [request extensions]: https://docs.rs/http/latest/http/request/struct.Request.html#method.extensions
[Response extensions]: https://docs.rs/http/latest/http/response/struct.Response.html#method.extensions [Response extensions]: https://docs.rs/http/latest/http/response/struct.Response.html#method.extensions
[`State`]: crate::extract::State

View file

@ -1,4 +1,4 @@
Add a fallback service to the router. Add a fallback [`Handler`] to the router.
This service will be called if no routes matches the incoming request. This service will be called if no routes matches the incoming request.
@ -13,7 +13,7 @@ use axum::{
let app = Router::new() let app = Router::new()
.route("/foo", get(|| async { /* ... */ })) .route("/foo", get(|| async { /* ... */ }))
.fallback(fallback.into_service()); .fallback(fallback);
async fn fallback(uri: Uri) -> (StatusCode, String) { async fn fallback(uri: Uri) -> (StatusCode, String) {
(StatusCode::NOT_FOUND, format!("No route for {}", uri)) (StatusCode::NOT_FOUND, format!("No route for {}", uri))

View file

@ -104,7 +104,7 @@ let api_routes = Router::new().nest("/users", get(|| async {}));
let app = Router::new() let app = Router::new()
.nest("/api", api_routes) .nest("/api", api_routes)
.fallback(fallback.into_service()); .fallback(fallback);
# let _: Router = app; # let _: Router = app;
``` ```
@ -132,11 +132,11 @@ async fn api_fallback() -> (StatusCode, Json<Value>) {
let api_routes = Router::new() let api_routes = Router::new()
.nest("/users", get(|| async {})) .nest("/users", get(|| async {}))
// add dedicated fallback for requests starting with `/api` // add dedicated fallback for requests starting with `/api`
.fallback(api_fallback.into_service()); .fallback(api_fallback);
let app = Router::new() let app = Router::new()
.nest("/api", api_routes) .nest("/api", api_routes)
.fallback(fallback.into_service()); .fallback(fallback);
# let _: Router = app; # let _: Router = app;
``` ```

View file

@ -3,10 +3,10 @@ Add another route to the router.
`path` is a string of path segments separated by `/`. Each segment `path` is a string of path segments separated by `/`. Each segment
can be either static, a capture, or a wildcard. can be either static, a capture, or a wildcard.
`service` is the [`Service`] that should receive the request if the path matches `method_router` is the [`MethodRouter`] that should receive the request if the
`path`. `service` will commonly be a handler wrapped in a method router like path matches `path`. `method_router` will commonly be a handler wrapped in a method
[`get`](crate::routing::get). See [`handler`](crate::handler) for more details router like [`get`](crate::routing::get). See [`handler`](crate::handler) for
on handlers. more details on handlers.
# Static paths # Static paths
@ -105,69 +105,6 @@ async fn serve_asset(Path(path): Path<String>) {}
# }; # };
``` ```
# Routing to any [`Service`]
axum also supports routing to general [`Service`]s:
```rust,no_run
use axum::{
Router,
body::Body,
routing::{any_service, get_service},
http::{Request, StatusCode},
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;
let app = Router::new()
.route(
// Any request to `/` goes to a service
"/",
// Services whose response body is not `axum::body::BoxBody`
// can be wrapped in `axum::routing::any_service` (or one of the other routing filters)
// to have the response body mapped
any_service(service_fn(|_: Request<Body>| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
)
.route(
"/foo",
// This service's response body is `axum::body::BoxBody` so
// it can be routed to directly.
service_fn(|req: Request<Body>| async move {
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let body = axum::body::boxed(body);
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
})
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Routing to arbitrary services in this way has complications for backpressure
([`Service::poll_ready`]). See the [Routing to services and backpressure] module
for more details.
[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure
# Panics # Panics
Panics if the route overlaps with another route: Panics if the route overlaps with another route:
@ -187,21 +124,3 @@ The static route `/foo` and the dynamic route `/:key` are not considered to
overlap and `/foo` will take precedence. overlap and `/foo` will take precedence.
Also panics if `path` is empty. Also panics if `path` is empty.
## Nesting
`route` cannot be used to nest `Router`s. Instead use [`Router::nest`].
Attempting to will result in a panic:
```rust,should_panic
use axum::{routing::get, Router};
let app = Router::new().route(
"/",
Router::new().route("/foo", get(|| async {})),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```

View file

@ -0,0 +1,81 @@
Add another route to the router that calls a [`Service`].
# Example
```rust,no_run
use axum::{
Router,
body::Body,
routing::{any_service, get_service},
http::{Request, StatusCode},
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;
let app = Router::new()
.route(
// Any request to `/` goes to a service
"/",
// Services whose response body is not `axum::body::BoxBody`
// can be wrapped in `axum::routing::any_service` (or one of the other routing filters)
// to have the response body mapped
any_service(service_fn(|_: Request<Body>| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
)
.route_service(
"/foo",
// This service's response body is `axum::body::BoxBody` so
// it can be routed to directly.
service_fn(|req: Request<Body>| async move {
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let body = axum::body::boxed(body);
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
})
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Routing to arbitrary services in this way has complications for backpressure
([`Service::poll_ready`]). See the [Routing to services and backpressure] module
for more details.
# Panics
Panics for the same reasons as [`Router::route`] or if you attempt to route to a
`Router`:
```rust,should_panic
use axum::{routing::get, Router};
let app = Router::new().route_service(
"/",
Router::new().route("/foo", get(|| async {})),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Use [`Router::nest`] instead.
[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure

View file

@ -1,7 +1,6 @@
#![doc = include_str!("../docs/error_handling.md")] #![doc = include_str!("../docs/error_handling.md")]
use crate::{ use crate::{
body::boxed,
extract::{FromRequest, RequestParts}, extract::{FromRequest, RequestParts},
http::{Request, StatusCode}, http::{Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
@ -113,16 +112,16 @@ where
} }
} }
impl<S, F, ReqBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()> impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where where
S: Service<Request<ReqBody>> + Clone + Send + 'static, S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send, S::Response: IntoResponse + Send,
S::Error: Send, S::Error: Send,
S::Future: Send, S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
Res: IntoResponse, Res: IntoResponse,
ReqBody: Send + 'static, B: Send + 'static,
{ {
type Response = Response; type Response = Response;
type Error = Infallible; type Error = Infallible;
@ -132,7 +131,7 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Request<ReqBody>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let f = self.f.clone(); let f = self.f.clone();
let clone = self.inner.clone(); let clone = self.inner.clone();
@ -152,18 +151,18 @@ where
#[allow(unused_macros)] #[allow(unused_macros)]
macro_rules! impl_service { macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => { ( $($ty:ident),* $(,)? ) => {
impl<S, F, ReqBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>> impl<S, F, B, Res, Fut, $($ty,)*> Service<Request<B>>
for HandleError<S, F, ($($ty,)*)> for HandleError<S, F, ($($ty,)*)>
where where
S: Service<Request<ReqBody>> + Clone + Send + 'static, S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send, S::Response: IntoResponse + Send,
S::Error: Send, S::Error: Send,
S::Future: Send, S::Future: Send,
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
Res: IntoResponse, Res: IntoResponse,
$( $ty: FromRequest<ReqBody> + Send,)* $( $ty: FromRequest<(), B> + Send,)*
ReqBody: Send + 'static, B: Send + 'static,
{ {
type Response = Response; type Response = Response;
type Error = Infallible; type Error = Infallible;
@ -175,7 +174,7 @@ macro_rules! impl_service {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn call(&mut self, req: Request<ReqBody>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let f = self.f.clone(); let f = self.f.clone();
let clone = self.inner.clone(); let clone = self.inner.clone();
@ -187,7 +186,7 @@ macro_rules! impl_service {
$( $(
let $ty = match $ty::from_request(&mut req).await { let $ty = match $ty::from_request(&mut req).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response().map(boxed)), Err(rejection) => return Ok(rejection.into_response()),
}; };
)* )*
@ -200,7 +199,7 @@ macro_rules! impl_service {
match inner.oneshot(req).await { match inner.oneshot(req).await {
Ok(res) => Ok(res.into_response()), Ok(res) => Ok(res.into_response()),
Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)), Err(err) => Ok(f($($ty),*, err).await.into_response()),
} }
}); });

View file

@ -73,14 +73,15 @@ use tower_service::Service;
pub struct Extension<T>(pub T); pub struct Extension<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Extension<T> impl<T, S, B> FromRequest<S, B> for Extension<T>
where where
T: Clone + Send + Sync + 'static, T: Clone + Send + Sync + 'static,
B: Send, B: Send,
S: Send,
{ {
type Rejection = ExtensionRejection; type Rejection = ExtensionRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let value = req let value = req
.extensions() .extensions()
.get::<T>() .get::<T>()

View file

@ -128,14 +128,15 @@ opaque_future! {
pub struct ConnectInfo<T>(pub T); pub struct ConnectInfo<T>(pub T);
#[async_trait] #[async_trait]
impl<B, T> FromRequest<B> for ConnectInfo<T> impl<S, B, T> FromRequest<S, B> for ConnectInfo<T>
where where
B: Send, B: Send,
S: Send,
T: Clone + Send + Sync + 'static, T: Clone + Send + Sync + 'static,
{ {
type Rejection = <Extension<Self> as FromRequest<B>>::Rejection; type Rejection = <Extension<Self> as FromRequest<S, B>>::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request(req).await?; let Extension(connect_info) = Extension::<Self>::from_request(req).await?;
Ok(connect_info) Ok(connect_info)
} }

View file

@ -36,15 +36,16 @@ use std::ops::Deref;
pub struct ContentLengthLimit<T, const N: u64>(pub T); pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait] #[async_trait]
impl<T, B, const N: u64> FromRequest<B> for ContentLengthLimit<T, N> impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where where
T: FromRequest<B>, T: FromRequest<S, B>,
T::Rejection: IntoResponse, T::Rejection: IntoResponse,
B: Send, B: Send,
S: Send,
{ {
type Rejection = ContentLengthLimitRejection<T::Rejection>; type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let content_length = req let content_length = req
.headers() .headers()
.get(http::header::CONTENT_LENGTH) .get(http::header::CONTENT_LENGTH)

View file

@ -21,13 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
pub struct Host(pub String); pub struct Host(pub String);
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Host impl<S, B> FromRequest<S, B> for Host
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = HostRejection; type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(req.headers()) { if let Some(host) = parse_forwarded(req.headers()) {
return Ok(Host(host.to_owned())); return Ok(Host(host.to_owned()));
} }

View file

@ -64,13 +64,14 @@ impl MatchedPath {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for MatchedPath impl<S, B> FromRequest<S, B> for MatchedPath
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = MatchedPathRejection; type Rejection = MatchedPathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let matched_path = req let matched_path = req
.extensions() .extensions()
.get::<Self>() .get::<Self>()
@ -84,7 +85,9 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::{extract::Extension, handler::Handler, routing::get, test_helpers::*, Router}; use crate::{
extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
};
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tower::layer::layer_fn; use tower::layer::layer_fn;
@ -93,7 +96,7 @@ mod tests {
#[derive(Clone)] #[derive(Clone)]
struct SetMatchedPathExtension<S>(S); struct SetMatchedPathExtension<S>(S);
impl<B, S> Service<Request<B>> for SetMatchedPathExtension<S> impl<S, B> Service<Request<B>> for SetMatchedPathExtension<S>
where where
S: Service<Request<B>>, S: Service<Request<B>>,
{ {

View file

@ -14,9 +14,10 @@ mod content_length_limit;
mod host; mod host;
mod raw_query; mod raw_query;
mod request_parts; mod request_parts;
mod state;
#[doc(inline)] #[doc(inline)]
pub use axum_core::extract::{FromRequest, RequestParts}; pub use axum_core::extract::{FromRef, FromRequest, RequestParts};
#[doc(inline)] #[doc(inline)]
#[allow(deprecated)] #[allow(deprecated)]
@ -27,6 +28,7 @@ pub use self::{
path::Path, path::Path,
raw_query::RawQuery, raw_query::RawQuery,
request_parts::{BodyStream, RawBody}, request_parts::{BodyStream, RawBody},
state::State,
}; };
#[doc(no_inline)] #[doc(no_inline)]
@ -73,13 +75,13 @@ pub use self::ws::WebSocketUpgrade;
#[doc(no_inline)] #[doc(no_inline)]
pub use crate::TypedHeader; pub use crate::TypedHeader;
pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> { pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or_else(BodyAlreadyExtracted::default) req.take_body().ok_or_else(BodyAlreadyExtracted::default)
} }
// this is duplicated in `axum-extra/src/extract/form.rs` // this is duplicated in `axum-extra/src/extract/form.rs`
pub(super) fn has_content_type<B>( pub(super) fn has_content_type<S, B>(
req: &RequestParts<B>, req: &RequestParts<S, B>,
expected_content_type: &mime::Mime, expected_content_type: &mime::Mime,
) -> bool { ) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {

View file

@ -50,14 +50,15 @@ pub struct Multipart {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for Multipart impl<S, B> FromRequest<S, B> for Multipart
where where
B: HttpBody<Data = Bytes> + Default + Unpin + Send + 'static, B: HttpBody<Data = Bytes> + Default + Unpin + Send + 'static,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = MultipartRejection; type Rejection = MultipartRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?; let stream = BodyStream::from_request(req).await?;
let headers = req.headers(); let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
@ -179,7 +180,7 @@ impl<'a> Field<'a> {
/// } /// }
/// ///
/// let app = Router::new().route("/upload", post(upload)); /// let app = Router::new().route("/upload", post(upload));
/// # let _: Router<axum::body::Body> = app; /// # let _: Router = app;
/// ``` /// ```
pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> { pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
self.inner self.inner

View file

@ -163,14 +163,15 @@ impl<T> DerefMut for Path<T> {
} }
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Path<T> impl<T, S, B> FromRequest<S, B> for Path<T>
where where
T: DeserializeOwned + Send, T: DeserializeOwned + Send,
B: Send, B: Send,
S: Send,
{ {
type Rejection = PathRejection; type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<UrlParams>() { let params = match req.extensions_mut().get::<UrlParams>() {
Some(UrlParams::Params(params)) => params, Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => { Some(UrlParams::InvalidUtf8InPathParam { key }) => {

View file

@ -49,14 +49,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T); pub struct Query<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Query<T> impl<T, S, B> FromRequest<S, B> for Query<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: Send, B: Send,
S: Send,
{ {
type Rejection = QueryRejection; type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query) let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?; .map_err(FailedToDeserializeQueryString::__private_new)?;
@ -81,7 +82,8 @@ mod tests {
use std::fmt::Debug; use std::fmt::Debug;
async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) { async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new(Request::builder().uri(uri.as_ref()).body(()).unwrap()); let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value); assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value);
} }

View file

@ -27,13 +27,14 @@ use std::convert::Infallible;
pub struct RawQuery(pub Option<String>); pub struct RawQuery(pub Option<String>);
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for RawQuery impl<S, B> FromRequest<S, B> for RawQuery
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().map(|query| query.to_owned()); let query = req.uri().query().map(|query| query.to_owned());
Ok(Self(query)) Ok(Self(query))
} }

View file

@ -86,13 +86,14 @@ pub struct OriginalUri(pub Uri);
#[cfg(feature = "original-uri")] #[cfg(feature = "original-uri")]
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for OriginalUri impl<S, B> FromRequest<S, B> for OriginalUri
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = Infallible; type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request(req) let uri = Extension::<Self>::from_request(req)
.await .await
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
@ -140,15 +141,16 @@ impl Stream for BodyStream {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for BodyStream impl<S, B> FromRequest<S, B> for BodyStream
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Into<Bytes>, B::Data: Into<Bytes>,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)? let body = take_body(req)?
.map_data(Into::into) .map_data(Into::into)
.map_err(|err| Error::new(err.into())); .map_err(|err| Error::new(err.into()));
@ -196,13 +198,14 @@ fn body_stream_traits() {
pub struct RawBody<B = Body>(pub B); pub struct RawBody<B = Body>(pub B);
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for RawBody<B> impl<S, B> FromRequest<S, B> for RawBody<B>
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = BodyAlreadyExtracted; type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?; let body = take_body(req)?;
Ok(Self(body)) Ok(Self(body))
} }

207
axum/src/extract/state.rs Normal file
View file

@ -0,0 +1,207 @@
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequest, RequestParts};
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
};
/// Extractor for state.
///
/// Note this extractor is not available to middleware. See ["Accessing state in
/// middleware"][state-from-middleware] for how to access state in middleware.
///
/// [state-from-middleware]: ../middleware/index.html#accessing-state-in-middleware
///
/// # With `Router`
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// // the application state
/// //
/// // here you can put configuration, database connection pools, or whatever
/// // state you need
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::with_state(state).route("/", get(handler));
///
/// async fn handler(
/// // access the state via the `State` extractor
/// // extracting a state of the wrong type results in a compile error
/// State(state): State<AppState>,
/// ) {
/// // use `state`...
/// }
/// # let _: Router<AppState> = app;
/// ```
///
/// # With `MethodRouter`
///
/// ```
/// use axum::{routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// let method_router_with_state = get(handler)
/// // provide the state so the handler can access it
/// .with_state(state);
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # With `Handler`
///
/// ```
/// use axum::{routing::get, handler::Handler, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
///
/// // provide the state so the handler can access it
/// let handler_with_state = handler.with_state(state);
///
/// # async {
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
/// .serve(handler_with_state.into_make_service())
/// .await
/// .expect("server failed");
/// # };
/// ```
///
/// # Substates
///
/// [`State`] only allows a single state type but you can use [`From`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
/// struct AppState {
/// // that holds some api specific state
/// api_state: ApiState,
/// }
///
/// // the api specific state
/// #[derive(Clone)]
/// struct ApiState {}
///
/// // support converting an `AppState` in an `ApiState`
/// impl FromRef<AppState> for ApiState {
/// fn from_ref(app_state: &AppState) -> ApiState {
/// app_state.api_state.clone()
/// }
/// }
///
/// let state = AppState {
/// api_state: ApiState {},
/// };
///
/// let app = Router::with_state(state)
/// .route("/", get(handler))
/// .route("/api/users", get(api_users));
///
/// async fn api_users(
/// // access the api specific state
/// State(api_state): State<ApiState>,
/// ) {
/// }
///
/// async fn handler(
/// // we can still access to top level state
/// State(state): State<AppState>,
/// ) {
/// }
/// # let _: Router<AppState> = app;
/// ```
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequest, RequestParts, FromRef};
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
/// // the extractor your library provides
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyLibraryExtractor
/// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state());
///
/// // ...
/// # todo!()
/// }
/// }
///
/// // the state your library needs
/// struct MyLibraryState {
/// // ...
/// }
/// ```
///
/// Note that you don't need to use the `State` extractor since you can access the state directly
/// from [`RequestParts`].
#[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S);
#[async_trait]
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
where
B: Send,
InnerState: FromRef<OuterState>,
OuterState: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(req.state());
Ok(Self(inner_state))
}
}
impl<S> Deref for State<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S> DerefMut for State<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -275,13 +275,14 @@ impl WebSocketUpgrade {
} }
#[async_trait] #[async_trait]
impl<B> FromRequest<B> for WebSocketUpgrade impl<S, B> FromRequest<S, B> for WebSocketUpgrade
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = WebSocketUpgradeRejection; type Rejection = WebSocketUpgradeRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() != Method::GET { if req.method() != Method::GET {
return Err(MethodNotGet.into()); return Err(MethodNotGet.into());
} }
@ -320,7 +321,7 @@ where
} }
} }
fn header_eq<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool { fn header_eq<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) { if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else { } else {
@ -328,7 +329,7 @@ fn header_eq<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) ->
} }
} }
fn header_contains<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool { fn header_contains<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) { let header = if let Some(header) = req.headers().get(&key) {
header header
} else { } else {

View file

@ -56,16 +56,17 @@ use std::ops::Deref;
pub struct Form<T>(pub T); pub struct Form<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Form<T> impl<T, S, B> FromRequest<S, B> for Form<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = FormRejection; type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET { if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default(); let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query) let value = serde_urlencoded::from_str(query)
@ -125,29 +126,27 @@ mod tests {
} }
async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) { async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new( let req = Request::builder()
Request::builder() .uri(uri.as_ref())
.uri(uri.as_ref()) .body(Empty::<Bytes>::new())
.body(Empty::<Bytes>::new()) .unwrap();
.unwrap(), let mut req = RequestParts::new(req);
);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value); assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
} }
async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) { async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
let mut req = RequestParts::new( let req = Request::builder()
Request::builder() .uri("http://example.com/test")
.uri("http://example.com/test") .method(Method::POST)
.method(Method::POST) .header(
.header( http::header::CONTENT_TYPE,
http::header::CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), )
) .body(Full::<Bytes>::new(
.body(Full::<Bytes>::new( serde_urlencoded::to_string(&value).unwrap().into(),
serde_urlencoded::to_string(&value).unwrap().into(), ))
)) .unwrap();
.unwrap(), let mut req = RequestParts::new(req);
);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value); assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
} }
@ -204,21 +203,20 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_incorrect_content_type() { async fn test_incorrect_content_type() {
let mut req = RequestParts::new( let req = Request::builder()
Request::builder() .uri("http://example.com/test")
.uri("http://example.com/test") .method(Method::POST)
.method(Method::POST) .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
.header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) .body(Full::<Bytes>::new(
.body(Full::<Bytes>::new( serde_urlencoded::to_string(&Pagination {
serde_urlencoded::to_string(&Pagination { size: Some(10),
size: Some(10), page: None,
page: None, })
}) .unwrap()
.unwrap() .into(),
.into(), ))
)) .unwrap();
.unwrap(), let mut req = RequestParts::new(req);
);
assert!(matches!( assert!(matches!(
Form::<Pagination>::from_request(&mut req) Form::<Pagination>::from_request(&mut req)
.await .await

View file

@ -19,29 +19,29 @@ opaque_future! {
pin_project! { pin_project! {
/// The response future for [`Layered`](super::Layered). /// The response future for [`Layered`](super::Layered).
pub struct LayeredFuture<S, ReqBody> pub struct LayeredFuture<B, S>
where where
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
{ {
#[pin] #[pin]
inner: Map<Oneshot<S, Request<ReqBody>>, fn(Result<S::Response, S::Error>) -> Response>, inner: Map<Oneshot<S, Request<B>>, fn(Result<S::Response, S::Error>) -> Response>,
} }
} }
impl<S, ReqBody> LayeredFuture<S, ReqBody> impl<B, S> LayeredFuture<B, S>
where where
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
{ {
pub(super) fn new( pub(super) fn new(
inner: Map<Oneshot<S, Request<ReqBody>>, fn(Result<S::Response, S::Error>) -> Response>, inner: Map<Oneshot<S, Request<B>>, fn(Result<S::Response, S::Error>) -> Response>,
) -> Self { ) -> Self {
Self { inner } Self { inner }
} }
} }
impl<S, ReqBody> Future for LayeredFuture<S, ReqBody> impl<B, S> Future for LayeredFuture<B, S>
where where
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
{ {
type Output = Response; type Output = Response;

View file

@ -11,29 +11,40 @@ use tower_service::Service;
/// An adapter that makes a [`Handler`] into a [`Service`]. /// An adapter that makes a [`Handler`] into a [`Service`].
/// ///
/// Created with [`Handler::into_service`]. /// Created with [`HandlerWithoutStateExt::into_service`].
pub struct IntoService<H, T, B> { ///
/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service
pub struct IntoService<H, T, S, B> {
handler: H, handler: H,
state: S,
_marker: PhantomData<fn() -> (T, B)>, _marker: PhantomData<fn() -> (T, B)>,
} }
impl<H, T, S, B> IntoService<H, T, S, B> {
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
}
#[test] #[test]
fn traits() { fn traits() {
use crate::test_helpers::*; use crate::test_helpers::*;
assert_send::<IntoService<(), NotSendSync, NotSendSync>>(); assert_send::<IntoService<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoService<(), NotSendSync, NotSendSync>>(); assert_sync::<IntoService<(), NotSendSync, (), NotSendSync>>();
} }
impl<H, T, B> IntoService<H, T, B> { impl<H, T, S, B> IntoService<H, T, S, B> {
pub(super) fn new(handler: H) -> Self { pub(super) fn new(handler: H, state: S) -> Self {
Self { Self {
handler, handler,
state,
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<H, T, B> fmt::Debug for IntoService<H, T, B> { impl<H, T, S, B> fmt::Debug for IntoService<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("IntoService") f.debug_tuple("IntoService")
.field(&format_args!("...")) .field(&format_args!("..."))
@ -41,22 +52,25 @@ impl<H, T, B> fmt::Debug for IntoService<H, T, B> {
} }
} }
impl<H, T, B> Clone for IntoService<H, T, B> impl<H, T, S, B> Clone for IntoService<H, T, S, B>
where where
H: Clone, H: Clone,
S: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
handler: self.handler.clone(), handler: self.handler.clone(),
state: self.state.clone(),
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<H, T, B> Service<Request<B>> for IntoService<H, T, B> impl<H, T, S, B> Service<Request<B>> for IntoService<H, T, S, B>
where where
H: Handler<T, B> + Clone + Send + 'static, H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static, B: Send + 'static,
S: Clone,
{ {
type Response = Response; type Response = Response;
type Error = Infallible; type Error = Infallible;
@ -74,7 +88,7 @@ where
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
let handler = self.handler.clone(); let handler = self.handler.clone();
let future = Handler::call(handler, req); let future = Handler::call(handler, self.state.clone(), req);
let future = future.map(Ok as _); let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future) super::future::IntoServiceFuture::new(future)

View file

@ -0,0 +1,85 @@
use super::Handler;
use crate::response::Response;
use http::Request;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower_service::Service;
pub(crate) struct IntoServiceStateInExtension<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, S, B> IntoServiceStateInExtension<H, T, S, B> {
pub(crate) fn new(handler: H) -> Self {
Self {
handler,
_marker: PhantomData,
}
}
}
impl<H, T, S, B> fmt::Debug for IntoServiceStateInExtension<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("IntoServiceStateInExtension")
.field(&format_args!("..."))
.finish()
}
}
impl<H, T, S, B> Clone for IntoServiceStateInExtension<H, T, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = super::future::IntoServiceFuture<H::Future>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or
// from `Layered` which bufferes in `<Layered as Handler>::call` and is therefore
// also always ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let state = req
.extensions_mut()
.remove::<S>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, state, req);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)
}
}

View file

@ -49,8 +49,11 @@ use tower_service::Service;
pub mod future; pub mod future;
mod into_service; mod into_service;
mod into_service_state_in_extension;
mod with_state;
pub use self::into_service::IntoService; pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension;
pub use self::{into_service::IntoService, with_state::WithState};
/// Trait for async functions that can be used to handle requests. /// Trait for async functions that can be used to handle requests.
/// ///
@ -59,13 +62,45 @@ pub use self::into_service::IntoService;
/// ///
/// See the [module docs](crate::handler) for more details. /// See the [module docs](crate::handler) for more details.
/// ///
/// # Converting `Handler`s into [`Service`]s
///
/// To convert `Handler`s into [`Service`]s you have to call either
/// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]:
///
/// ```
/// use tower::Service;
/// use axum::{
/// extract::State,
/// body::Body,
/// http::Request,
/// handler::{HandlerWithoutStateExt, Handler},
/// };
///
/// // this handler doesn't require any state
/// async fn one() {}
/// // so it can be converted to a service with `HandlerWithoutStateExt::into_service`
/// assert_service(one.into_service());
///
/// // this handler requires state
/// async fn two(_: State<String>) {}
/// // so we have to provide it
/// let handler_with_state = two.with_state(String::new());
/// // which gives us a `Service`
/// assert_service(handler_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service<S>(service: S)
/// where
/// S: Service<Request<Body>>,
/// {}
/// ```
#[doc = include_str!("../docs/debugging_handler_type_errors.md")] #[doc = include_str!("../docs/debugging_handler_type_errors.md")]
pub trait Handler<T, B = Body>: Clone + Send + Sized + 'static { pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
/// The type of future calling this handler returns. /// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static; type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request. /// Call the handler with the given request.
fn call(self, req: Request<B>) -> Self::Future; fn call(self, state: S, req: Request<B>) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler. /// Apply a [`tower::Layer`] to the handler.
/// ///
@ -103,112 +138,26 @@ pub trait Handler<T, B = Body>: Clone + Send + Sized + 'static {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # }; /// # };
/// ``` /// ```
fn layer<L>(self, layer: L) -> Layered<L::Service, T> fn layer<L>(self, layer: L) -> Layered<L, Self, T, S, B>
where where
L: Layer<IntoService<Self, T, B>>, L: Layer<WithState<Self, T, S, B>>,
{ {
Layered::new(layer.layer(self.into_service())) Layered {
layer,
handler: self,
_marker: PhantomData,
}
} }
/// Convert the handler into a [`Service`]. /// Convert the handler into a [`Service`] by providing the state
/// fn with_state(self, state: S) -> WithState<Self, T, S, B> {
/// This is commonly used together with [`Router::fallback`]: WithState {
/// service: IntoService::new(self, state),
/// ```rust }
/// use axum::{
/// Server,
/// handler::Handler,
/// http::{Uri, Method, StatusCode},
/// response::IntoResponse,
/// routing::{get, Router},
/// };
/// use tower::make::Shared;
/// use std::net::SocketAddr;
///
/// async fn handler(method: Method, uri: Uri) -> (StatusCode, String) {
/// (StatusCode::NOT_FOUND, format!("Nothing to see at {} {}", method, uri))
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async {}))
/// .fallback(handler.into_service());
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`Router::fallback`]: crate::routing::Router::fallback
fn into_service(self) -> IntoService<Self, T, B> {
IntoService::new(self)
}
/// Convert the handler into a [`MakeService`].
///
/// This allows you to serve a single handler if you don't need any routing:
///
/// ```rust
/// use axum::{
/// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse,
/// };
/// use std::net::SocketAddr;
///
/// async fn handler(method: Method, uri: Uri, body: String) -> String {
/// format!("received `{} {}` with body `{:?}`", method, uri, body)
/// }
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(handler.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, B>> {
IntoMakeService::new(self.into_service())
}
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// response::IntoResponse,
/// extract::ConnectInfo,
/// };
/// use std::net::SocketAddr;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
/// format!("Hello {}", addr)
/// }
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(handler.into_make_service_with_connect_info::<SocketAddr>())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, B>, C> {
IntoMakeServiceWithConnectInfo::new(self.into_service())
} }
} }
impl<F, Fut, Res, B> Handler<(), B> for F impl<F, Fut, Res, S, B> Handler<(), S, B> for F
where where
F: FnOnce() -> Fut + Clone + Send + 'static, F: FnOnce() -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
@ -217,7 +166,7 @@ where
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _req: Request<B>) -> Self::Future { fn call(self, _state: S, _req: Request<B>) -> Self::Future {
Box::pin(async move { self().await.into_response() }) Box::pin(async move { self().await.into_response() })
} }
} }
@ -225,19 +174,20 @@ where
macro_rules! impl_handler { macro_rules! impl_handler {
( $($ty:ident),* $(,)? ) => { ( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<F, Fut, B, Res, $($ty,)*> Handler<($($ty,)*), B> for F impl<F, Fut, S, B, Res, $($ty,)*> Handler<($($ty,)*), S, B> for F
where where
F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send, Fut: Future<Output = Res> + Send,
B: Send + 'static, B: Send + 'static,
S: Send + 'static,
Res: IntoResponse, Res: IntoResponse,
$( $ty: FromRequest<B> + Send,)* $( $ty: FromRequest<S, B> + Send,)*
{ {
type Future = Pin<Box<dyn Future<Output = Response> + Send>>; type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<B>) -> Self::Future { fn call(self, state: S, req: Request<B>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let mut req = RequestParts::new(req); let mut req = RequestParts::with_state(state, req);
$( $(
let $ty = match $ty::from_request(&mut req).await { let $ty = match $ty::from_request(&mut req).await {
@ -260,58 +210,116 @@ all_the_tuples!(impl_handler);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
/// ///
/// Created with [`Handler::layer`]. See that method for more details. /// Created with [`Handler::layer`]. See that method for more details.
pub struct Layered<S, T> { pub struct Layered<L, H, T, S, B> {
svc: S, layer: L,
_input: PhantomData<fn() -> T>, handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
} }
impl<S, T> fmt::Debug for Layered<S, T> impl<L, H, T, S, B> fmt::Debug for Layered<L, H, T, S, B>
where where
S: fmt::Debug, L: fmt::Debug,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Layered").field("svc", &self.svc).finish() f.debug_struct("Layered")
.field("layer", &self.layer)
.finish()
} }
} }
impl<S, T> Clone for Layered<S, T> impl<L, H, T, S, B> Clone for Layered<L, H, T, S, B>
where where
S: Clone, L: Clone,
H: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self::new(self.svc.clone()) Self {
layer: self.layer.clone(),
handler: self.handler.clone(),
_marker: PhantomData,
}
} }
} }
impl<S, T, ReqBody> Handler<T, ReqBody> for Layered<S, T> impl<H, S, T, B, L> Handler<T, S, B> for Layered<L, H, T, S, B>
where where
S: Service<Request<ReqBody>, Error = Infallible> + Clone + Send + 'static, L: Layer<WithState<H, T, S, B>> + Clone + Send + 'static,
S::Response: IntoResponse, H: Handler<T, S, B>,
S::Future: Send, L::Service: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse,
<L::Service as Service<Request<B>>>::Future: Send,
T: 'static, T: 'static,
ReqBody: Send + 'static, S: 'static,
B: Send + 'static,
{ {
type Future = future::LayeredFuture<S, ReqBody>; type Future = future::LayeredFuture<B, L::Service>;
fn call(self, req: Request<ReqBody>) -> Self::Future { fn call(self, state: S, req: Request<B>) -> Self::Future {
use futures_util::future::{FutureExt, Map}; use futures_util::future::{FutureExt, Map};
let future: Map<_, fn(Result<S::Response, S::Error>) -> _> = let svc = self.handler.with_state(state);
self.svc.oneshot(req).map(|result| match result { let svc = self.layer.layer(svc);
Ok(res) => res.into_response(),
Err(err) => match err {}, let future: Map<
}); _,
fn(
Result<
<L::Service as Service<Request<B>>>::Response,
<L::Service as Service<Request<B>>>::Error,
>,
) -> _,
> = svc.oneshot(req).map(|result| match result {
Ok(res) => res.into_response(),
Err(err) => match err {},
});
future::LayeredFuture::new(future) future::LayeredFuture::new(future)
} }
} }
impl<S, T> Layered<S, T> { /// Extension trait for [`Handler`]s that don't have state.
pub(crate) fn new(svc: S) -> Self { ///
Self { /// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`].
svc, ///
_input: PhantomData, /// [`MakeService`]: tower::make::MakeService
} pub trait HandlerWithoutStateExt<T, B>: Handler<T, (), B> {
/// Convert the handler into a [`Service`] and no state.
fn into_service(self) -> WithState<Self, T, (), B>;
/// Convert the handler into a [`MakeService`] and no state.
///
/// See [`WithState::into_make_service`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, (), B>>;
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection and has no state.
///
/// See [`WithState::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, (), B>, C>;
}
impl<H, T, B> HandlerWithoutStateExt<T, B> for H
where
H: Handler<T, (), B>,
{
fn into_service(self) -> WithState<Self, T, (), B> {
self.with_state(())
}
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, (), B>> {
self.with_state(()).into_make_service()
}
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, (), B>, C> {
self.with_state(()).into_make_service_with_connect_info()
} }
} }

View file

@ -0,0 +1,144 @@
use super::{Handler, IntoService};
use crate::{extract::connect_info::IntoMakeServiceWithConnectInfo, routing::IntoMakeService};
use http::Request;
use std::task::{Context, Poll};
use tower_service::Service;
/// A [`Handler`] which has access to some state.
///
/// Implements [`Service`].
///
/// The state can be extracted with [`State`](crate::extract::State).
///
/// Created with [`Handler::with_state`].
pub struct WithState<H, T, S, B> {
pub(super) service: IntoService<H, T, S, B>,
}
impl<H, T, S, B> WithState<H, T, S, B> {
/// Get a reference to the state.
pub fn state(&self) -> &S {
self.service.state()
}
}
impl<H, T, S, B> WithState<H, T, S, B> {
/// Convert the handler into a [`MakeService`].
///
/// This allows you to serve a single handler if you don't need any routing:
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// extract::State,
/// http::{Uri, Method},
/// response::IntoResponse,
/// };
/// use std::net::SocketAddr;
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// async fn handler(State(state): State<AppState>) {
/// // ...
/// }
///
/// let app = handler.with_state(AppState {});
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<IntoService<H, T, S, B>> {
IntoMakeService::new(self.service)
}
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// response::IntoResponse,
/// extract::{ConnectInfo, State},
/// };
/// use std::net::SocketAddr;
///
/// #[derive(Clone)]
/// struct AppState {};
///
/// async fn handler(
/// ConnectInfo(addr): ConnectInfo<SocketAddr>,
/// State(state): State<AppState>,
/// ) -> String {
/// format!("Hello {}", addr)
/// }
///
/// let app = handler.with_state(AppState {});
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service_with_connect_info::<SocketAddr>())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<H, T, S, B>, C> {
IntoMakeServiceWithConnectInfo::new(self.service)
}
}
impl<H, T, S, B> Service<Request<B>> for WithState<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone,
{
type Response = <IntoService<H, T, S, B> as Service<Request<B>>>::Response;
type Error = <IntoService<H, T, S, B> as Service<Request<B>>>::Error;
type Future = <IntoService<H, T, S, B> as Service<Request<B>>>::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.service.call(req)
}
}
impl<H, T, S, B> std::fmt::Debug for WithState<H, T, S, B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WithState")
.field("service", &self.service)
.finish()
}
}
impl<H, T, S, B> Clone for WithState<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
service: self.service.clone(),
}
}
}

View file

@ -94,16 +94,17 @@ use std::ops::{Deref, DerefMut};
pub struct Json<T>(pub T); pub struct Json<T>(pub T);
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Json<T> impl<T, S, B> FromRequest<S, B> for Json<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
B: HttpBody + Send, B: HttpBody + Send,
B::Data: Send, B::Data: Send,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
S: Send,
{ {
type Rejection = JsonRejection; type Rejection = JsonRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if json_content_type(req) { if json_content_type(req) {
let bytes = Bytes::from_request(req).await?; let bytes = Bytes::from_request(req).await?;
@ -136,7 +137,7 @@ where
} }
} }
fn json_content_type<B>(req: &RequestParts<B>) -> bool { fn json_content_type<S, B>(req: &RequestParts<S, B>) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type content_type
} else { } else {

View file

@ -168,13 +168,48 @@
//! pool of database connections or clients to other services. //! pool of database connections or clients to other services.
//! //!
//! The two most common ways of doing that are: //! The two most common ways of doing that are:
//! - Using the [`State`] extractor.
//! - Using request extensions //! - Using request extensions
//! - Using closure captures //! - Using closure captures
//! //!
//! ## Using the [`State`] extractor
//!
//! ```rust,no_run
//! use axum::{
//! extract::State,
//! routing::get,
//! Router,
//! };
//! use std::sync::Arc;
//!
//! struct AppState {
//! // ...
//! }
//!
//! let shared_state = Arc::new(AppState { /* ... */ });
//!
//! let app = Router::with_state(shared_state)
//! .route("/", get(handler));
//!
//! async fn handler(
//! State(state): State<Arc<AppState>>,
//! ) {
//! // ...
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! You should prefer using [`State`] if possible since it's more type safe. The downside is that
//! its less dynamic than request extensions.
//!
//! See [`State`] for more details about accessing state.
//!
//! ## Using request extensions //! ## Using request extensions
//! //!
//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension) //! Another way to extract state in handlers is using [`Extension`](crate::extract::Extension) as
//! as layer and extractor: //! layer and extractor:
//! //!
//! ```rust,no_run //! ```rust,no_run
//! use axum::{ //! use axum::{
@ -184,18 +219,18 @@
//! }; //! };
//! use std::sync::Arc; //! use std::sync::Arc;
//! //!
//! struct State { //! struct AppState {
//! // ... //! // ...
//! } //! }
//! //!
//! let shared_state = Arc::new(State { /* ... */ }); //! let shared_state = Arc::new(AppState { /* ... */ });
//! //!
//! let app = Router::new() //! let app = Router::new()
//! .route("/", get(handler)) //! .route("/", get(handler))
//! .layer(Extension(shared_state)); //! .layer(Extension(shared_state));
//! //!
//! async fn handler( //! async fn handler(
//! Extension(state): Extension<Arc<State>>, //! Extension(state): Extension<Arc<AppState>>,
//! ) { //! ) {
//! // ... //! // ...
//! } //! }
@ -223,11 +258,11 @@
//! use std::sync::Arc; //! use std::sync::Arc;
//! use serde::Deserialize; //! use serde::Deserialize;
//! //!
//! struct State { //! struct AppState {
//! // ... //! // ...
//! } //! }
//! //!
//! let shared_state = Arc::new(State { /* ... */ }); //! let shared_state = Arc::new(AppState { /* ... */ });
//! //!
//! let app = Router::new() //! let app = Router::new()
//! .route( //! .route(
@ -245,11 +280,11 @@
//! }), //! }),
//! ); //! );
//! //!
//! async fn get_user(Path(user_id): Path<String>, state: Arc<State>) { //! async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
//! // ... //! // ...
//! } //! }
//! //!
//! async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<State>) { //! async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
//! // ... //! // ...
//! } //! }
//! //!
@ -263,7 +298,7 @@
//! ``` //! ```
//! //!
//! The downside to this approach is that it's a little more verbose than using //! The downside to this approach is that it's a little more verbose than using
//! extensions. //! [`State`] or extensions.
//! //!
//! # Building integrations for axum //! # Building integrations for axum
//! //!
@ -350,6 +385,7 @@
//! [`Infallible`]: std::convert::Infallible //! [`Infallible`]: std::convert::Infallible
//! [load shed]: tower::load_shed //! [load shed]: tower::load_shed
//! [`axum-core`]: http://crates.io/crates/axum-core //! [`axum-core`]: http://crates.io/crates/axum-core
//! [`State`]: crate::extract::State
#![warn( #![warn(
clippy::all, clippy::all,

View file

@ -45,13 +45,14 @@ use tower_service::Service;
/// struct RequireAuth; /// struct RequireAuth;
/// ///
/// #[async_trait] /// #[async_trait]
/// impl<B> FromRequest<B> for RequireAuth /// impl<S, B> FromRequest<S, B> for RequireAuth
/// where /// where
/// B: Send, /// B: Send,
/// S: Send,
/// { /// {
/// type Rejection = StatusCode; /// type Rejection = StatusCode;
/// ///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { /// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req /// let auth_header = req
/// .headers() /// .headers()
/// .get(header::AUTHORIZATION) /// .get(header::AUTHORIZATION)
@ -166,23 +167,23 @@ where
} }
} }
impl<S, E, ReqBody> Service<Request<ReqBody>> for FromExtractor<S, E> impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
where where
E: FromRequest<ReqBody> + 'static, E: FromRequest<(), B> + 'static,
ReqBody: Default + Send + 'static, B: Default + Send + 'static,
S: Service<Request<ReqBody>> + Clone, S: Service<Request<B>> + Clone,
S::Response: IntoResponse, S::Response: IntoResponse,
{ {
type Response = Response; type Response = Response;
type Error = S::Error; type Error = S::Error;
type Future = ResponseFuture<ReqBody, S, E>; type Future = ResponseFuture<B, S, E>;
#[inline] #[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx) self.inner.poll_ready(cx)
} }
fn call(&mut self, req: Request<ReqBody>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
let extract_future = Box::pin(async move { let extract_future = Box::pin(async move {
let mut req = RequestParts::new(req); let mut req = RequestParts::new(req);
let extracted = E::from_request(&mut req).await; let extracted = E::from_request(&mut req).await;
@ -201,35 +202,37 @@ where
pin_project! { pin_project! {
/// Response future for [`FromExtractor`]. /// Response future for [`FromExtractor`].
#[allow(missing_debug_implementations)] #[allow(missing_debug_implementations)]
pub struct ResponseFuture<ReqBody, S, E> pub struct ResponseFuture<B, S, E>
where where
E: FromRequest<ReqBody>, E: FromRequest<(), B>,
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
{ {
#[pin] #[pin]
state: State<ReqBody, S, E>, state: State<B, S, E>,
svc: Option<S>, svc: Option<S>,
} }
} }
pin_project! { pin_project! {
#[project = StateProj] #[project = StateProj]
enum State<ReqBody, S, E> enum State<B, S, E>
where where
E: FromRequest<ReqBody>, E: FromRequest<(), B>,
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
{ {
Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> }, Extracting {
future: BoxFuture<'static, (RequestParts<(), B>, Result<E, E::Rejection>)>,
},
Call { #[pin] future: S::Future }, Call { #[pin] future: S::Future },
} }
} }
impl<ReqBody, S, E> Future for ResponseFuture<ReqBody, S, E> impl<B, S, E> Future for ResponseFuture<B, S, E>
where where
E: FromRequest<ReqBody>, E: FromRequest<(), B>,
S: Service<Request<ReqBody>>, S: Service<Request<B>>,
S::Response: IntoResponse, S::Response: IntoResponse,
ReqBody: Default, B: Default,
{ {
type Output = Result<Response, S::Error>; type Output = Result<Response, S::Error>;
@ -277,13 +280,14 @@ mod tests {
struct RequireAuth; struct RequireAuth;
#[async_trait::async_trait] #[async_trait::async_trait]
impl<B> FromRequest<B> for RequireAuth impl<S, B> FromRequest<S, B> for RequireAuth
where where
B: Send, B: Send,
S: Send,
{ {
type Rejection = StatusCode; type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req if let Some(auth) = req
.headers() .headers()
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)

Some files were not shown because too many files have changed in this diff Show more