Add type safe state extractor (#1155)

* begin threading the state through

* Pass state to extractors

* make state extractor work

* make sure nesting with different states work

* impl Service for MethodRouter<()>

* Fix some of axum-macro's tests

* Implement more traits for `State`

* Update examples to use `State`

* consistent naming of request body param

* swap type params

* Default the state param to ()

* fix docs references

* Docs and handler state refactoring

* docs clean ups

* more consistent naming

* when does MethodRouter implement Service?

* add missing docs

* use `Router`'s default state type param

* changelog

* don't use default type param for FromRequest and RequestParts

probably safer for library authors so you don't accidentally forget

* fix examples

* minor docs tweaks

* clarify how to convert handlers into services

* group methods in one impl block

* make sure merged `MethodRouter`s can access state

* fix docs link

* test merge with same state type

* Document how to access state from middleware

* Port cookie extractors to use state to extract keys (#1250)

* Updates ECOSYSTEM with a new sample project (#1252)

* Avoid unhelpful compiler suggestion (#1251)

* fix docs typo

* document how library authors should access state

* Add `RequestParts::with_state`

* fix example

* apply suggestions from review

* add relevant changes to axum-extra and axum-core changelogs

* Add `route_service_with_tsr`

* fix trybuild expectations

* make sure `SpaRouter` works with routers that have state

* Change order of type params on FromRequest and RequestParts

* reverse order of `RequestParts::with_state` args to match type params

* Add `FromRef` trait (#1268)

* Add `FromRef` trait

* Remove unnecessary type params

* format

* fix docs link

* format examples

* Avoid unnecessary `MethodRouter`

* apply suggestions from review

Co-authored-by: Dani Pardo <dani.pardo@inmensys.com>
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-08-17 17:13:31 +02:00 committed by GitHub
parent 90dbd52ee4
commit 423308de3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
132 changed files with 2404 additions and 1126 deletions

View file

@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which
represents the state ([#1155])
[#1155]: https://github.com/tokio-rs/axum/pull/1155
# 0.2.6 (18. June, 2022)

View file

@ -0,0 +1,23 @@
/// Used to do reference-to-value conversions thus not consuming the input value.
///
/// This is mainly used with [`State`] to extract "substates" from a reference to main application
/// state.
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
/// Converts to this type from a reference to the input type.
fn from_ref(input: &T) -> Self;
}
impl<T> FromRef<T> for T
where
T: Clone,
{
fn from_ref(input: &T) -> Self {
input.clone()
}
}

View file

@ -12,9 +12,12 @@ use std::convert::Infallible;
pub mod rejection;
mod from_ref;
mod request_parts;
mod tuple;
pub use self::from_ref::FromRef;
/// Types that can be created from requests.
///
/// See [`axum::extract`] for more details.
@ -42,13 +45,15 @@ mod tuple;
/// struct MyExtractor;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for MyExtractor
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send, // required by `async_trait`
/// // these bounds are required by `async_trait`
/// B: Send,
/// S: Send,
/// {
/// type Rejection = http::StatusCode;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
@ -60,20 +65,21 @@ mod tuple;
/// [`http::Request<B>`]: http::Request
/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
#[async_trait]
pub trait FromRequest<B>: Sized {
pub trait FromRequest<S, B>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection>;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection>;
}
/// The type used with [`FromRequest`] to extract data from requests.
///
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<B> {
pub struct RequestParts<S, B> {
state: S,
method: Method,
uri: Uri,
version: Version,
@ -82,8 +88,8 @@ pub struct RequestParts<B> {
body: Option<B>,
}
impl<B> RequestParts<B> {
/// Create a new `RequestParts`.
impl<B> RequestParts<(), B> {
/// Create a new `RequestParts` without any state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
@ -91,6 +97,19 @@ impl<B> RequestParts<B> {
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn new(req: Request<B>) -> Self {
Self::with_state((), req)
}
}
impl<S, B> RequestParts<S, B> {
/// Create a new `RequestParts` with the given state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
let (
http::request::Parts {
method,
@ -104,6 +123,7 @@ impl<B> RequestParts<B> {
) = req.into_parts();
RequestParts {
state,
method,
uri,
version,
@ -130,10 +150,14 @@ impl<B> RequestParts<B> {
/// use http::{Method, Uri};
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for MyExtractor {
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Infallible> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Infallible> {
/// let method = req.extract::<Method>().await?;
/// let path = req.extract::<Uri>().await?.path().to_owned();
///
@ -141,7 +165,10 @@ impl<B> RequestParts<B> {
/// }
/// }
/// ```
pub async fn extract<E: FromRequest<B>>(&mut self) -> Result<E, E::Rejection> {
pub async fn extract<E>(&mut self) -> Result<E, E::Rejection>
where
E: FromRequest<S, B>,
{
E::from_request(self).await
}
@ -153,6 +180,7 @@ impl<B> RequestParts<B> {
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> {
let Self {
state: _,
method,
uri,
version,
@ -245,30 +273,37 @@ impl<B> RequestParts<B> {
pub fn take_body(&mut self) -> Option<B> {
self.body.take()
}
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
}
#[async_trait]
impl<T, B> FromRequest<B> for Option<T>
impl<S, T, B> FromRequest<S, B> for Option<T>
where
T: FromRequest<B>,
T: FromRequest<S, B>,
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Option<T>, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok())
}
}
#[async_trait]
impl<T, B> FromRequest<B> for Result<T, T::Rejection>
impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where
T: FromRequest<B>,
T: FromRequest<S, B>,
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await)
}
}

View file

@ -6,16 +6,18 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible;
#[async_trait]
impl<B> FromRequest<B> for Request<B>
impl<S, B> FromRequest<S, B> for Request<B>
where
B: Send,
S: Clone + Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let req = std::mem::replace(
req,
RequestParts {
state: req.state().clone(),
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
@ -30,37 +32,40 @@ where
}
#[async_trait]
impl<B> FromRequest<B> for Method
impl<S, B> FromRequest<S, B> for Method
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.method().clone())
}
}
#[async_trait]
impl<B> FromRequest<B> for Uri
impl<S, B> FromRequest<S, B> for Uri
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone())
}
}
#[async_trait]
impl<B> FromRequest<B> for Version
impl<S, B> FromRequest<S, B> for Version
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.version())
}
}
@ -71,27 +76,29 @@ where
///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait]
impl<B> FromRequest<B> for HeaderMap
impl<S, B> FromRequest<S, B> for HeaderMap
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.headers().clone())
}
}
#[async_trait]
impl<B> FromRequest<B> for Bytes
impl<S, B> FromRequest<S, B> for Bytes
where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = BytesRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let bytes = crate::body::to_bytes(body)
@ -103,15 +110,16 @@ where
}
#[async_trait]
impl<B> FromRequest<B> for String
impl<S, B> FromRequest<S, B> for String
where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = StringRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
let bytes = crate::body::to_bytes(body)
@ -126,13 +134,14 @@ where
}
#[async_trait]
impl<B> FromRequest<B> for http::request::Parts
impl<S, B> FromRequest<S, B> for http::request::Parts
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
@ -159,6 +168,6 @@ fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
}
}
pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted)
}

View file

@ -4,13 +4,14 @@ use async_trait::async_trait;
use std::convert::Infallible;
#[async_trait]
impl<B> FromRequest<B> for ()
impl<S, B> FromRequest<S, B> for ()
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> {
async fn from_request(_: &mut RequestParts<S, B>) -> Result<(), Self::Rejection> {
Ok(())
}
}
@ -21,14 +22,15 @@ macro_rules! impl_from_request {
( $($ty:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<B, $($ty,)*> FromRequest<B> for ($($ty,)*)
impl<S, B, $($ty,)*> FromRequest<S, B> for ($($ty,)*)
where
$( $ty: FromRequest<B> + Send, )*
$( $ty: FromRequest<S, B> + Send, )*
B: Send,
S: Send,
{
type Rejection = Response;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )*
Ok(($($ty,)*))
}

View file

@ -17,15 +17,21 @@ and this project adheres to [Semantic Versioning].
literal `Response`
- **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170])
- **change:** axum-extra's MSRV is now 1.60 ([#1239])
- **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extracts the keys
from the router's state, rather than extensions
- **added:** Add Protocol Buffer extractor and response ([#1239])
- **added:** Add `Either*` types for combining extractors and responses into a
single type ([#1263])
- **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262])
- **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and
`SignedCookieJar` so they're easier to use in custom middleware
- **breaking:** `Resource` has a new `S` type param which represents the state ([#1155])
- **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155])
- **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155])
[#1086]: https://github.com/tokio-rs/axum/pull/1086
[#1119]: https://github.com/tokio-rs/axum/pull/1119
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1170]: https://github.com/tokio-rs/axum/pull/1170
[#1214]: https://github.com/tokio-rs/axum/pull/1214
[#1239]: https://github.com/tokio-rs/axum/pull/1239

View file

@ -190,15 +190,16 @@ macro_rules! impl_traits_for_either {
$last:ident $(,)?
) => {
#[async_trait]
impl<B, $($ident),*, $last> FromRequest<B> for $either<$($ident),*, $last>
impl<S, B, $($ident),*, $last> FromRequest<S, B> for $either<$($ident),*, $last>
where
$($ident: FromRequest<B>),*,
$last: FromRequest<B>,
$($ident: FromRequest<S, B>),*,
$last: FromRequest<S, B>,
B: Send,
S: Send,
{
type Rejection = $last::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$(
if let Ok(value) = req.extract().await {
return Ok(Self::$ident(value));

View file

@ -30,13 +30,14 @@ use std::ops::{Deref, DerefMut};
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<B> FromRequest<B> for Session
/// impl<S, B> FromRequest<S, B> for Session
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // load session...
/// # unimplemented!()
/// }
@ -45,13 +46,14 @@ use std::ops::{Deref, DerefMut};
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<B> FromRequest<B> for CurrentUser
/// impl<S, B> FromRequest<S, B> for CurrentUser
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = Response;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // loading a `CurrentUser` requires first loading the `Session`
/// //
/// // by using `Cached<Session>` we avoid extracting the session more than
@ -88,14 +90,15 @@ pub struct Cached<T>(pub T);
struct CachedEntry<T>(T);
#[async_trait]
impl<B, T> FromRequest<B> for Cached<T>
impl<S, B, T> FromRequest<S, B> for Cached<T>
where
B: Send,
T: FromRequest<B> + Clone + Send + Sync + 'static,
S: Send,
T: FromRequest<S, B> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => {
@ -139,13 +142,14 @@ mod tests {
struct Extractor(Instant);
#[async_trait]
impl<B> FromRequest<B> for Extractor
impl<S, B> FromRequest<S, B> for Extractor
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now()))
}

View file

@ -80,7 +80,7 @@ pub use cookie::Key;
/// let app = Router::new()
/// .route("/sessions", post(create_session))
/// .route("/me", get(me));
/// # let app: Router<axum::body::Body> = app;
/// # let app: Router = app;
/// ```
#[derive(Debug, Default)]
pub struct CookieJar {
@ -88,13 +88,14 @@ pub struct CookieJar {
}
#[async_trait]
impl<B> FromRequest<B> for CookieJar
impl<S, B> FromRequest<S, B> for CookieJar
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(Self::from_headers(req.headers()))
}
}
@ -226,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Body, http::Request, routing::get, Extension, Router};
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use tower::ServiceExt;
macro_rules! cookie_test {
@ -245,12 +246,15 @@ mod tests {
jar.remove(Cookie::named("key"))
}
let app = Router::<Body>::new()
let state = AppState {
key: Key::generate(),
custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state)
.route("/set", get(set_cookie))
.route("/get", get(get_cookie))
.route("/remove", get(remove_cookie))
.layer(Extension(Key::generate()))
.layer(Extension(CustomKey(Key::generate())));
.route("/remove", get(remove_cookie));
let res = app
.clone()
@ -298,6 +302,24 @@ mod tests {
cookie_test!(private_cookies, PrivateCookieJar);
cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);
#[derive(Clone)]
struct AppState {
key: Key,
custom_key: CustomKey,
}
impl FromRef<AppState> for Key {
fn from_ref(state: &AppState) -> Key {
state.key.clone()
}
}
impl FromRef<AppState> for CustomKey {
fn from_ref(state: &AppState) -> CustomKey {
state.custom_key.clone()
}
}
#[derive(Clone)]
struct CustomKey(Key);
@ -313,9 +335,12 @@ mod tests {
format!("{:?}", jar.get("key"))
}
let app = Router::<Body>::new()
.route("/get", get(get_cookie))
.layer(Extension(Key::generate()));
let state = AppState {
key: Key::generate(),
custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state).route("/get", get(get_cookie));
let res = app
.clone()

View file

@ -1,9 +1,8 @@
use super::{cookies_from_request, set_cookies, Cookie, Key};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
Extension,
};
use cookie::PrivateJar;
use http::HeaderMap;
@ -23,9 +22,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// ```rust
/// use axum::{
/// Router,
/// Extension,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
@ -45,22 +43,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// }
/// }
///
/// // our application state
/// #[derive(Clone)]
/// struct AppState {
/// // that holds the key used to sign cookies
/// key: Key,
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
/// let state = AppState {
/// // Generate a secure key
/// //
/// // You probably don't wanna generate a new one each time the app starts though
/// let key = Key::generate();
/// key: Key::generate(),
/// };
///
/// let app = Router::new()
/// let app = Router::with_state(state)
/// .route("/set", post(set_secret))
/// .route("/get", get(get_secret))
/// // add extension with the key so `PrivateCookieJar` can access it
/// .layer(Extension(key));
/// # let app: Router<axum::body::Body> = app;
/// .route("/get", get(get_secret));
/// # let app: Router<_> = app;
/// ```
pub struct PrivateCookieJar<K = Key> {
jar: cookie::CookieJar,
key: Key,
// The key used to extract the key extension. Allows users to use multiple keys for different
// The key used to extract the key. Allows users to use multiple keys for different
// jars. Maybe a library wants its own key.
_marker: PhantomData<K>,
}
@ -75,15 +87,17 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
}
#[async_trait]
impl<B, K> FromRequest<B> for PrivateCookieJar<K>
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where
B: Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let key = req.extract::<Extension<K>>().await?.0.into();
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state());
let key = k.into();
let PrivateCookieJar {
jar,
key,

View file

@ -1,9 +1,8 @@
use super::{cookies_from_request, set_cookies};
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
extract::{FromRef, FromRequest, RequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
Extension,
};
use cookie::SignedJar;
use cookie::{Cookie, Key};
@ -24,9 +23,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// ```rust
/// use axum::{
/// Router,
/// Extension,
/// routing::{post, get},
/// extract::TypedHeader,
/// extract::{TypedHeader, FromRef},
/// response::{IntoResponse, Redirect},
/// headers::authorization::{Authorization, Bearer},
/// http::StatusCode,
@ -63,22 +61,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// # todo!()
/// }
///
/// // our application state
/// #[derive(Clone)]
/// struct AppState {
/// // that holds the key used to sign cookies
/// key: Key,
/// }
///
/// // this impl tells `SignedCookieJar` how to access the key from our state
/// impl FromRef<AppState> for Key {
/// fn from_ref(state: &AppState) -> Self {
/// state.key.clone()
/// }
/// }
///
/// let state = AppState {
/// // Generate a secure key
/// //
/// // You probably don't wanna generate a new one each time the app starts though
/// let key = Key::generate();
/// key: Key::generate(),
/// };
///
/// let app = Router::new()
/// let app = Router::with_state(state)
/// .route("/sessions", post(create_session))
/// .route("/me", get(me))
/// // add extension with the key so `SignedCookieJar` can access it
/// .layer(Extension(key));
/// # let app: Router<axum::body::Body> = app;
/// .route("/me", get(me));
/// # let app: Router<_> = app;
/// ```
pub struct SignedCookieJar<K = Key> {
jar: cookie::CookieJar,
key: Key,
// The key used to extract the key extension. Allows users to use multiple keys for different
// The key used to extract the key. Allows users to use multiple keys for different
// jars. Maybe a library wants its own key.
_marker: PhantomData<K>,
}
@ -93,15 +105,17 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
}
#[async_trait]
impl<B, K> FromRequest<B> for SignedCookieJar<K>
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where
B: Send,
K: Into<Key> + Clone + Send + Sync + 'static,
S: Send,
K: FromRef<S> + Into<Key>,
{
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let key = req.extract::<Extension<K>>().await?.0.into();
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let k = K::from_ref(req.state());
let key = k.into();
let SignedCookieJar {
jar,
key,

View file

@ -55,16 +55,17 @@ impl<T> Deref for Form<T> {
}
#[async_trait]
impl<T, B> FromRequest<B> for Form<T>
impl<T, S, B> FromRequest<S, B> for Form<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
@ -85,7 +86,7 @@ where
}
// this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type<B>(req: &RequestParts<B>, expected_content_type: &mime::Mime) -> bool {
fn has_content_type<S, B>(req: &RequestParts<S, B>, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {

View file

@ -58,14 +58,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Query<T>
impl<T, S, B> FromRequest<S, B> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send,
{
type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;

View file

@ -107,15 +107,16 @@ impl<E, R> DerefMut for WithRejection<E, R> {
}
#[async_trait]
impl<B, E, R> FromRequest<B> for WithRejection<E, R>
impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R>
where
B: Send,
E: FromRequest<B>,
S: Send,
E: FromRequest<S, B>,
R: From<E::Rejection> + IntoResponse,
{
type Rejection = R;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let extractor = req.extract::<E>().await?;
Ok(WithRejection(extractor, PhantomData))
}
@ -134,10 +135,14 @@ mod tests {
struct TestRejection;
#[async_trait]
impl<B: Send> FromRequest<B> for TestExtractor {
impl<S, B> FromRequest<S, B> for TestExtractor
where
B: Send,
S: Send,
{
type Rejection = ();
async fn from_request(_: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Err(())
}
}

View file

@ -19,15 +19,15 @@ pub use self::or::Or;
///
/// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you
/// can with [`Handler::layer`].
pub trait HandlerCallWithExtractors<T, B>: Sized {
pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the extracted inputs.
fn call(self, extractors: T) -> <Self as HandlerCallWithExtractors<T, B>>::Future;
fn call(self, state: S, extractors: T) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, B> {
fn into_handler(self) -> IntoHandler<Self, T, S, B> {
IntoHandler {
handler: self,
_marker: PhantomData,
@ -67,10 +67,14 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// struct AdminPermissions {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for AdminPermissions {
/// impl<S, B> FromRequest<S, B> for AdminPermissions
/// where
/// B: Send,
/// S: Send,
/// {
/// // check for admin permissions...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
@ -78,10 +82,14 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// struct User {}
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for User {
/// impl<S, B> FromRequest<S, B> for User
/// where
/// B: Send,
/// S: Send,
/// {
/// // check for a logged in user...
/// # type Rejection = ();
/// # async fn from_request(req: &mut axum::extract::RequestParts<B>) -> Result<Self, Self::Rejection> {
/// # async fn from_request(req: &mut axum::extract::RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// # todo!()
/// # }
/// }
@ -96,9 +104,9 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
/// );
/// # let _: Router = app;
/// ```
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, B>
fn or<R, Rt>(self, rhs: R) -> Or<Self, R, T, Rt, S, B>
where
R: HandlerCallWithExtractors<Rt, B>,
R: HandlerCallWithExtractors<Rt, S, B>,
{
Or {
lhs: self,
@ -111,7 +119,7 @@ pub trait HandlerCallWithExtractors<T, B>: Sized {
macro_rules! impl_handler_call_with {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), B> for F
impl<F, Fut, S, B, $($ty,)*> HandlerCallWithExtractors<($($ty,)*), S, B> for F
where
F: FnOnce($($ty,)*) -> Fut,
Fut: Future + Send + 'static,
@ -122,8 +130,9 @@ macro_rules! impl_handler_call_with {
fn call(
self,
_state: S,
($($ty,)*): ($($ty,)*),
) -> <Self as HandlerCallWithExtractors<($($ty,)*), B>>::Future {
) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
}
}
@ -152,34 +161,35 @@ impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13,
///
/// Created with [`HandlerCallWithExtractors::into_handler`].
#[allow(missing_debug_implementations)]
pub struct IntoHandler<H, T, B> {
pub struct IntoHandler<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, B)>,
_marker: PhantomData<fn() -> (T, S, B)>,
}
impl<H, T, B> Handler<T, B> for IntoHandler<H, T, B>
impl<H, T, S, B> Handler<T, S, B> for IntoHandler<H, T, S, B>
where
H: HandlerCallWithExtractors<T, B> + Clone + Send + 'static,
T: FromRequest<B> + Send + 'static,
H: HandlerCallWithExtractors<T, S, B> + Clone + Send + 'static,
T: FromRequest<S, B> + Send + 'static,
T::Rejection: Send,
B: Send + 'static,
S: Clone + Send + 'static,
{
type Future = BoxFuture<'static, Response>;
fn call(self, req: http::Request<B>) -> Self::Future {
fn call(self, state: S, req: http::Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);
let mut req = RequestParts::with_state(state.clone(), req);
match req.extract::<T>().await {
Ok(t) => self.handler.call(t).await,
Ok(t) => self.handler.call(state, t).await,
Err(rejection) => rejection.into_response(),
}
})
}
}
impl<H, T, B> Copy for IntoHandler<H, T, B> where H: Copy {}
impl<H, T, S, B> Copy for IntoHandler<H, T, S, B> where H: Copy {}
impl<H, T, B> Clone for IntoHandler<H, T, B>
impl<H, T, S, B> Clone for IntoHandler<H, T, S, B>
where
H: Clone,
{

View file

@ -15,16 +15,16 @@ use std::{future::Future, marker::PhantomData};
///
/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or).
#[allow(missing_debug_implementations)]
pub struct Or<L, R, Lt, Rt, B> {
pub struct Or<L, R, Lt, Rt, S, B> {
pub(super) lhs: L,
pub(super) rhs: R,
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, B)>,
pub(super) _marker: PhantomData<fn() -> (Lt, Rt, S, B)>,
}
impl<B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, B> for Or<L, R, Lt, Rt, B>
impl<S, B, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, S, B> for Or<L, R, Lt, Rt, S, B>
where
L: HandlerCallWithExtractors<Lt, B> + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Send + 'static,
L: HandlerCallWithExtractors<Lt, S, B> + Send + 'static,
R: HandlerCallWithExtractors<Rt, S, B> + Send + 'static,
Rt: Send + 'static,
Lt: Send + 'static,
B: Send + 'static,
@ -37,46 +37,48 @@ where
fn call(
self,
state: S,
extractors: Either<Lt, Rt>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, B>>::Future {
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors {
Either::E1(lt) => self
.lhs
.call(lt)
.call(state, lt)
.map(IntoResponse::into_response as _)
.left_future(),
Either::E2(rt) => self
.rhs
.call(rt)
.call(state, rt)
.map(IntoResponse::into_response as _)
.right_future(),
}
}
}
impl<B, L, R, Lt, Rt> Handler<(Lt, Rt), B> for Or<L, R, Lt, Rt, B>
impl<S, B, L, R, Lt, Rt> Handler<(Lt, Rt), S, B> for Or<L, R, Lt, Rt, S, B>
where
L: HandlerCallWithExtractors<Lt, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, B> + Clone + Send + 'static,
Lt: FromRequest<B> + Send + 'static,
Rt: FromRequest<B> + Send + 'static,
L: HandlerCallWithExtractors<Lt, S, B> + Clone + Send + 'static,
R: HandlerCallWithExtractors<Rt, S, B> + Clone + Send + 'static,
Lt: FromRequest<S, B> + Send + 'static,
Rt: FromRequest<S, B> + Send + 'static,
Lt::Rejection: Send,
Rt::Rejection: Send,
B: Send + 'static,
S: Clone + Send + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;
fn call(self, req: Request<B>) -> Self::Future {
fn call(self, state: S, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);
let mut req = RequestParts::with_state(state.clone(), req);
if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(lt).await;
return self.lhs.call(state, lt).await;
}
if let Ok(rt) = req.extract::<Rt>().await {
return self.rhs.call(rt).await;
return self.rhs.call(state, rt).await;
}
StatusCode::NOT_FOUND.into_response()
@ -84,14 +86,14 @@ where
}
}
impl<L, R, Lt, Rt, B> Copy for Or<L, R, Lt, Rt, B>
impl<L, R, Lt, Rt, S, B> Copy for Or<L, R, Lt, Rt, S, B>
where
L: Copy,
R: Copy,
{
}
impl<L, R, Lt, Rt, B> Clone for Or<L, R, Lt, Rt, B>
impl<L, R, Lt, Rt, S, B> Clone for Or<L, R, Lt, Rt, S, B>
where
L: Clone,
R: Clone,

View file

@ -98,16 +98,17 @@ impl<S> JsonLines<S, AsResponse> {
}
#[async_trait]
impl<B, T> FromRequest<B> for JsonLines<T, AsExtractor>
impl<S, B, T> FromRequest<S, B> for JsonLines<T, AsExtractor>
where
B: HttpBody + Send + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
T: DeserializeOwned,
S: Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
// `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
// so we can call `AsyncRead::lines` and then convert it back to a `Stream`

View file

@ -97,16 +97,17 @@ use std::ops::{Deref, DerefMut};
pub struct ProtoBuf<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for ProtoBuf<T>
impl<T, S, B> FromRequest<S, B> for ProtoBuf<T>
where
T: Message + Default,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = ProtoBufRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let mut bytes = Bytes::from_request(req).await?;
match T::decode(&mut bytes) {

View file

@ -1,12 +1,13 @@
//! Additional types for defining routes.
use axum::{
handler::Handler,
handler::{Handler, HandlerWithoutStateExt},
http::Request,
response::{IntoResponse, Redirect},
routing::{any, MethodRouter},
Router,
};
use std::{convert::Infallible, future::ready};
use std::{convert::Infallible, future::ready, sync::Arc};
use tower_service::Service;
mod resource;
@ -29,7 +30,7 @@ pub use self::typed::{FirstElementIs, TypedPath};
pub use self::spa::SpaRouter;
/// Extension trait that adds additional methods to [`Router`].
pub trait RouterExt<B>: sealed::Sealed {
pub trait RouterExt<S, B>: sealed::Sealed {
/// Add a typed `GET` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
@ -39,7 +40,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -52,7 +53,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -65,7 +66,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -78,7 +79,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -91,7 +92,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -104,7 +105,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -117,7 +118,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -130,7 +131,7 @@ pub trait RouterExt<B>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -159,7 +160,14 @@ pub trait RouterExt<B>: sealed::Sealed {
/// .route_with_tsr("/bar/", get(|| async {}));
/// # let _: Router = app;
/// ```
fn route_with_tsr<T>(self, path: &str, service: T) -> Self
fn route_with_tsr(self, path: &str, method_router: MethodRouter<S, B>) -> Self
where
Self: Sized;
/// Add another route to the router with an additional "trailing slash redirect" route.
///
/// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
@ -167,14 +175,15 @@ pub trait RouterExt<B>: sealed::Sealed {
Self: Sized;
}
impl<B> RouterExt<B> for Router<B>
impl<S, B> RouterExt<S, B> for Router<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -184,7 +193,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -194,7 +203,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -204,7 +213,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -214,7 +223,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -224,7 +233,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -234,7 +243,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -244,41 +253,56 @@ where
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::trace(handler))
}
fn route_with_tsr<T>(mut self, path: &str, service: T) -> Self
fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self
where
Self: Sized,
{
self = self.route(path, method_router);
let redirect_service = {
let path: Arc<str> = path.into();
(move || ready(Redirect::permanent(&path))).into_service()
};
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route_service(path_without_trailing_slash, redirect_service)
} else {
self.route_service(&format!("{}/", path), redirect_service)
}
}
fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
Self: Sized,
{
self = self.route(path, service);
self = self.route_service(path, service);
let redirect = Redirect::permanent(path);
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route(
path_without_trailing_slash,
(move || ready(redirect.clone())).into_service(),
any(move || ready(redirect.clone())),
)
} else {
self.route(
&format!("{}/", path),
(move || ready(redirect.clone())).into_service(),
)
self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
}
}
}
mod sealed {
pub trait Sealed {}
impl<B> Sealed for axum::Router<B> {}
impl<S, B> Sealed for axum::Router<S, B> {}
}
#[cfg(test)]

View file

@ -1,13 +1,9 @@
use axum::{
body::Body,
handler::Handler,
http::Request,
response::IntoResponse,
routing::{delete, get, on, post, MethodFilter},
routing::{delete, get, on, post, MethodFilter, MethodRouter},
Router,
};
use std::{convert::Infallible, fmt};
use tower_service::Service;
/// A resource which defines a set of conventional CRUD routes.
///
@ -34,14 +30,15 @@ use tower_service::Service;
/// .destroy(|Path(user_id): Path<u64>| async {});
///
/// let app = Router::new().merge(users);
/// # let _: Router<axum::body::Body> = app;
/// # let _: Router = app;
/// ```
pub struct Resource<B = Body> {
#[derive(Debug)]
pub struct Resource<S = (), B = Body> {
pub(crate) name: String,
pub(crate) router: Router<B>,
pub(crate) router: Router<S, B>,
}
impl<B> Resource<B>
impl<B> Resource<(), B>
where
B: axum::body::HttpBody + Send + 'static,
{
@ -49,16 +46,29 @@ where
///
/// All routes will be nested at `/{resource_name}`.
pub fn named(resource_name: &str) -> Self {
Self::named_with((), resource_name)
}
}
impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named_with(state: S, resource_name: &str) -> Self {
Self {
name: resource_name.to_owned(),
router: Default::default(),
router: Router::with_state(state),
}
}
/// Add a handler at `GET /{resource_name}`.
pub fn index<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = self.index_create_path();
@ -68,7 +78,7 @@ where
/// Add a handler at `POST /{resource_name}`.
pub fn create<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = self.index_create_path();
@ -78,7 +88,7 @@ where
/// Add a handler at `GET /{resource_name}/new`.
pub fn new<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = format!("/{}/new", self.name);
@ -88,7 +98,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id`.
pub fn show<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -98,7 +108,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`.
pub fn edit<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = format!("/{0}/:{0}_id/edit", self.name);
@ -108,7 +118,7 @@ where
/// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`.
pub fn update<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -118,7 +128,7 @@ where
/// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`.
pub fn destroy<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<T, S, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -133,13 +143,8 @@ where
format!("/{0}/:{0}_id", self.name)
}
fn route<T>(mut self, path: &str, svc: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
self.router = self.router.route(path, svc);
fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
self.router = self.router.route(path, method_router);
self
}
}
@ -150,21 +155,13 @@ impl<B> From<Resource<B>> for Router<B> {
}
}
impl<B> fmt::Debug for Resource<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Resource")
.field("name", &self.name)
.field("router", &self.router)
.finish()
}
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use axum::{extract::Path, http::Method, Router};
use tower::ServiceExt;
use http::Request;
use tower::{Service, ServiceExt};
#[tokio::test]
async fn works() {
@ -220,7 +217,7 @@ mod tests {
);
}
async fn call_route(app: &mut Router, method: Method, uri: &str) -> String {
async fn call_route(app: &mut Router<()>, method: Method, uri: &str) -> String {
let res = app
.ready()
.await

View file

@ -36,7 +36,7 @@ use tower_service::Service;
/// .merge(spa)
/// // we can still add other routes
/// .route("/api/foo", get(api_foo));
/// # let _: Router<axum::body::Body> = app;
/// # let _: Router = app;
///
/// async fn api_foo() {}
/// ```
@ -101,7 +101,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// .index_file("another_file.html");
///
/// let app = Router::new().merge(spa);
/// # let _: Router<axum::body::Body> = app;
/// # let _: Router = app;
/// ```
pub fn index_file<P>(mut self, path: P) -> Self
where
@ -136,7 +136,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// }
///
/// let app = Router::new().merge(spa);
/// # let _: Router<axum::body::Body> = app;
/// # let _: Router = app;
/// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> {
SpaRouter {
@ -147,7 +147,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
}
}
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<B>
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
where
F: Clone + Send + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
@ -162,7 +162,7 @@ where
Router::new()
.nest(&spa.paths.assets_path, assets_service)
.fallback(
.fallback_service(
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
)
}
@ -264,6 +264,13 @@ mod tests {
let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);
Router::<Body>::new().merge(spa);
Router::<_, Body>::new().merge(spa);
}
#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router<String> = Router::with_state(String::new())
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}));
}
}

View file

@ -60,7 +60,7 @@ use http::Uri;
/// async fn users_destroy(_: UsersCollection) { /* ... */ }
///
/// #
/// # let app: Router<axum::body::Body> = app;
/// # let app: Router = app;
/// ```
///
/// # Using `#[derive(TypedPath)]`

View file

@ -203,7 +203,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr
#[allow(warnings)]
fn #name()
where
#ty: ::axum::extract::FromRequest<#body_ty> + Send,
#ty: ::axum::extract::FromRequest<(), #body_ty> + Send,
{}
}
})

View file

@ -218,16 +218,17 @@ fn impl_struct_by_extracting_each_field(
Ok(quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: Send,
{
type Rejection = #rejection_ident;
async fn from_request(
req: &mut ::axum::extract::RequestParts<B>,
req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self {
#(#extract_fields)*
@ -422,7 +423,7 @@ fn extract_each_field_rejection(
Ok(quote_spanned! {ty_span=>
#[allow(non_camel_case_types)]
#variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body>>::Rejection),
#variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection),
})
})
.collect::<syn::Result<Vec<_>>>()?;
@ -609,26 +610,26 @@ fn impl_struct_by_extracting_all_at_once(
quote! { #rejection }
} else {
quote! {
<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
}
};
let rejection_bound = rejection.as_ref().map(|rejection| {
if generic_ident.is_some() {
quote! {
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<B>>::Rejection>,
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>,
}
} else {
quote! {
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection>,
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>,
}
}
}).unwrap_or_default();
let impl_generics = if generic_ident.is_some() {
quote! { B, T }
quote! { S, B, T }
} else {
quote! { B }
quote! { S, B }
};
let type_generics = generic_ident
@ -653,18 +654,19 @@ fn impl_struct_by_extracting_all_at_once(
Ok(quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<#impl_generics> ::axum::extract::FromRequest<B> for #ident #type_generics
impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics
where
#path<#via_type_generics>: ::axum::extract::FromRequest<B>,
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound
B: ::std::marker::Send,
S: ::std::marker::Send,
{
type Rejection = #associated_rejection_type;
async fn from_request(
req: &mut ::axum::extract::RequestParts<B>,
req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<B>::from_request(req)
::axum::extract::FromRequest::<S, B>::from_request(req)
.await
.map(|#path(value)| #value_to_self)
.map_err(::std::convert::From::from)
@ -709,7 +711,7 @@ fn impl_enum_by_extracting_all_at_once(
quote! { #rejection }
} else {
quote! {
<#path<Self> as ::axum::extract::FromRequest<B>>::Rejection
<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
}
};
@ -718,18 +720,19 @@ fn impl_enum_by_extracting_all_at_once(
Ok(quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send,
{
type Rejection = #associated_rejection_type;
async fn from_request(
req: &mut ::axum::extract::RequestParts<B>,
req: &mut ::axum::extract::RequestParts<S, B>,
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::<B>::from_request(req)
::axum::extract::FromRequest::<S, B>::from_request(req)
.await
.map(|#path(inner)| inner)
.map_err(::std::convert::From::from)

View file

@ -125,7 +125,7 @@ mod typed_path;
/// ```
/// pub struct ViaExtractor<T>(pub T);
///
/// // impl<T, B> FromRequest<B> for ViaExtractor<T> { ... }
/// // impl<T, S, B> FromRequest<S, B> for ViaExtractor<T> { ... }
/// ```
///
/// More complex via extractors are not supported and require writing a manual implementation.
@ -223,14 +223,15 @@ mod typed_path;
/// struct OtherExtractor;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for OtherExtractor
/// impl<S, B> FromRequest<S, B> for OtherExtractor
/// where
/// B: Send + 'static,
/// B: Send,
/// S: Send,
/// {
/// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }

View file

@ -127,13 +127,14 @@ fn expand_named_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
@ -229,13 +230,14 @@ fn expand_unnamed_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::Path::from_request(req)
.await
.map(|path| path.0)
@ -310,13 +312,14 @@ fn expand_unit_fields(
let from_request_impl = quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<B> ::axum::extract::FromRequest<B> for #ident
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
{
type Rejection = #rejection_assoc_type;
async fn from_request(req: &mut ::axum::extract::RequestParts<B>) -> ::std::result::Result<Self, Self::Rejection> {
async fn from_request(req: &mut ::axum::extract::RequestParts<S, B>) -> ::std::result::Result<Self, Self::Rejection> {
if req.uri().path() == <Self as ::axum_extra::routing::TypedPath>::PATH {
Ok(Self)
} else {
@ -387,7 +390,7 @@ enum Segment {
fn path_rejection() -> TokenStream {
quote! {
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<B>>::Rejection
<::axum::extract::Path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection
}
}

View file

@ -1,17 +1,17 @@
error[E0277]: the trait bound `bool: FromRequest<Body>` is not satisfied
error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied
--> tests/debug_handler/fail/argument_not_extractor.rs:4:23
|
4 | async fn handler(foo: bool) {}
| ^^^^ the trait `FromRequest<Body>` is not implemented for `bool`
| ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool`
|
= help: the following other types implement trait `FromRequest<B>`:
()
(T1, T2)
(T1, T2, T3)
(T1, T2, T3, T4)
(T1, T2, T3, T4, T5)
(T1, T2, T3, T4, T5, T6)
(T1, T2, T3, T4, T5, T6, T7)
(T1, T2, T3, T4, T5, T6, T7, T8)
and 33 others
= help: the following other types implement trait `FromRequest<S, B>`:
<() as FromRequest<S, B>>
<(T1, T2) as FromRequest<S, B>>
<(T1, T2, T3) as FromRequest<S, B>>
<(T1, T2, T3, T4) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequest<S, B>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest<S, B>>
and 34 others
= help: see issue #48214

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A;
#[async_trait]
impl<B> FromRequest<B> for A
impl<S, B> FromRequest<S, B> for A
where
B: Send + 'static,
B: Send,
S: Send,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_mut.rs:23:22
--> tests/debug_handler/fail/extract_self_mut.rs:24:22
|
23 | async fn handler(&mut self) {}
24 | async fn handler(&mut self) {}
| ^^^^^^^^^

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A;
#[async_trait]
impl<B> FromRequest<B> for A
impl<S, B> FromRequest<S, B> for A
where
B: Send + 'static,
B: Send,
S: Send,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,5 +1,5 @@
error: Handlers must only take owned values
--> tests/debug_handler/fail/extract_self_ref.rs:23:22
--> tests/debug_handler/fail/extract_self_ref.rs:24:22
|
23 | async fn handler(&self) {}
24 | async fn handler(&self) {}
| ^^^^^

View file

@ -120,13 +120,14 @@ impl A {
}
#[async_trait]
impl<B> FromRequest<B> for A
impl<S, B> FromRequest<S, B> for A
where
B: Send + 'static,
B: Send,
S: Send,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -7,13 +7,14 @@ use axum_macros::debug_handler;
struct A;
#[async_trait]
impl<B> FromRequest<B> for A
impl<S, B> FromRequest<S, B> for A
where
B: Send + 'static,
B: Send,
S: Send,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)]
struct Extractor(#[from_request(via(Extension), via(Extension))] State);
struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State);
#[derive(Clone)]
struct State;

View file

@ -1,13 +1,5 @@
error: `via` specified more than once
--> tests/from_request/fail/double_via_attr.rs:5:49
--> tests/from_request/fail/double_via_attr.rs:4:55
|
5 | struct Extractor(#[from_request(via(Extension), via(Extension))] State);
4 | struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State);
| ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/double_via_attr.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router};
use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
@ -7,5 +7,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {}
fn main() {
Router::<Body>::new().route("/", get(foo));
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
5 | struct Extractor<T>(T);
| ^
warning: unused import: `Extension`
--> tests/from_request/fail/generic_without_via.rs:1:38
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via.rs:10:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via.rs:10:42
|
10 | Router::<Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
10 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>`
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router};
use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
@ -8,5 +8,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {}
fn main() {
Router::<Body>::new().route("/", get(foo));
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
6 | struct Extractor<T>(T);
| ^
warning: unused import: `Extension`
--> tests/from_request/fail/generic_without_via_rejection.rs:1:38
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection.rs:11:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection.rs:11:42
|
11 | Router::<Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>`
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -1,4 +1,4 @@
use axum::{body::Body, routing::get, Extension, Router};
use axum::{body::Body, routing::get, Router};
use axum_macros::FromRequest;
#[derive(FromRequest, Clone)]
@ -8,5 +8,5 @@ struct Extractor<T>(T);
async fn foo(_: Extractor<()>) {}
fn main() {
Router::<Body>::new().route("/", get(foo));
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque
6 | struct Extractor<T>(T);
| ^
warning: unused import: `Extension`
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:1:38
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46
|
1 | use axum::{body::Body, routing::get, Extension, Router};
| ^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default
error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future<Output = ()> {foo}: Handler<_, _>` is not satisfied
--> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:42
|
11 | Router::<Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
11 | Router::<(), Body>::new().route("/", get(foo));
| --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future<Output = ()> {foo}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>`
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|

View file

@ -4,15 +4,15 @@ error: cannot use `rejection` without `via`
18 | #[from_request(rejection(MyRejection))]
| ^^^^^^^^^^^
error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {handler}: Handler<_, _>` is not satisfied
error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {handler}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50
|
10 | let _: Router = Router::new().route("/", get(handler).post(handler_result));
| --- ^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(MyExtractor) -> impl Future<Output = ()> {handler}`
| --- ^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(MyExtractor) -> impl Future<Output = ()> {handler}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>`
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `axum::routing::get`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
@ -20,18 +20,18 @@ note: required by a bound in `axum::routing::get`
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get`
= note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}: Handler<_, _>` is not satisfied
error[E0277]: the trait bound `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64
|
10 | let _: Router = Router::new().route("/", get(handler).post(handler_result));
| ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}`
| ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Result<MyExtractor, MyRejection>) -> impl Future<Output = ()> {handler_result}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `Layered<S, T>`
note: required by a bound in `MethodRouter::<B>::post`
= help: the trait `Handler<T, S, B>` is implemented for `Layered<L, H, T, S, B>`
note: required by a bound in `MethodRouter::<S, B>::post`
--> $WORKSPACE/axum/src/routing/method_routing.rs
|
| chained_handler_fn!(post, POST);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::<B>::post`
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::<S, B>::post`
= note: this error originates in the macro `chained_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error), via(Extension))]
#[from_request(rejection_derive(!Error), via(axum::Extension))]
struct Extractor {
config: String,
}

View file

@ -1,13 +1,5 @@
error: cannot use both `rejection_derive` and `via`
--> tests/from_request/fail/rejection_derive_and_via.rs:5:42
--> tests/from_request/fail/rejection_derive_and_via.rs:4:42
|
5 | #[from_request(rejection_derive(!Error), via(Extension))]
4 | #[from_request(rejection_derive(!Error), via(axum::Extension))]
| ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/rejection_derive_and_via.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,8 +1,7 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)]
#[from_request(via(Extension), rejection_derive(!Error))]
#[from_request(via(axum::Extension), rejection_derive(!Error))]
struct Extractor {
config: String,
}

View file

@ -1,13 +1,5 @@
error: cannot use both `via` and `rejection_derive`
--> tests/from_request/fail/via_and_rejection_derive.rs:5:32
--> tests/from_request/fail/via_and_rejection_derive.rs:4:38
|
5 | #[from_request(via(Extension), rejection_derive(!Error))]
4 | #[from_request(via(axum::Extension), rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/via_and_rejection_derive.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -1,9 +1,8 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)]
#[from_request(via(Extension))]
struct Extractor(#[from_request(via(Extension))] State);
#[from_request(via(axum::Extension))]
struct Extractor(#[from_request(via(axum::Extension))] State);
#[derive(Clone)]
struct State;

View file

@ -1,13 +1,5 @@
error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container
--> tests/from_request/fail/via_on_container_and_field.rs:6:33
--> tests/from_request/fail/via_on_container_and_field.rs:5:33
|
6 | struct Extractor(#[from_request(via(Extension))] State);
5 | struct Extractor(#[from_request(via(axum::Extension))] State);
| ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/via_on_container_and_field.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -15,7 +15,7 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<Body, Rejection = JsonRejection>,
Extractor: FromRequest<(), Body, Rejection = JsonRejection>,
{
}

View file

@ -14,13 +14,14 @@ struct Extractor {
struct OtherExtractor;
#[async_trait]
impl<B> FromRequest<B> for OtherExtractor
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send + 'static,
B: Send,
S: Send,
{
type Rejection = OtherExtractorRejection;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}

View file

@ -5,7 +5,7 @@ struct Extractor {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{
}

View file

@ -5,7 +5,7 @@ struct Extractor();
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{
}

View file

@ -8,5 +8,5 @@ enum Extractor {}
async fn foo(_: Extractor) {}
fn main() {
Router::<Body>::new().route("/", get(foo));
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -18,7 +18,7 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<Body, Rejection = ExtractorRejection>,
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
{
}

View file

@ -25,7 +25,7 @@ struct Extractor {
fn assert_from_request()
where
Extractor: FromRequest<Body, Rejection = ExtractorRejection>,
Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>,
{
}

View file

@ -28,14 +28,15 @@ struct MyExtractor {
struct OtherExtractor;
#[async_trait]
impl<B> FromRequest<B> for OtherExtractor
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send + 'static,
S: Send,
{
// this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String);
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
todo!()
}
}

View file

@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String);
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body>,
Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{
}

View file

@ -13,7 +13,7 @@ struct Payload {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body>,
Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{
}

View file

@ -27,7 +27,7 @@ struct Payload {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body>,
Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{
}

View file

@ -1,5 +1,5 @@
use axum::Extension;
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest)]
struct Extractor(#[from_request(via(Extension))] State);
@ -9,7 +9,7 @@ struct State;
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body>,
Extractor: axum::extract::FromRequest<(), axum::body::Body>,
{
}

View file

@ -5,7 +5,7 @@ struct Extractor;
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>,
{
}

View file

@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is
(T0, T1, T2, T3)
and 138 others
= note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath`
= note: required because of the requirements on the impl of `FromRequest<B>` for `axum::extract::Path<MyPath>`
= note: required because of the requirements on the impl of `FromRequest<S, B>` for `axum::extract::Path<MyPath>`
= note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info)

View file

@ -40,7 +40,7 @@ impl Default for MyRejection {
}
fn main() {
axum::Router::<axum::body::Body>::new()
axum::Router::<(), axum::body::Body>::new()
.typed_get(|_: Result<MyPathNamed, MyRejection>| async {})
.typed_post(|_: Result<MyPathUnnamed, MyRejection>| async {})
.typed_put(|_: Result<MyPathUnit, MyRejection>| async {});

View file

@ -9,7 +9,7 @@ struct MyPath {
}
fn main() {
axum::Router::<axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id");
assert_eq!(

View file

@ -20,7 +20,7 @@ struct UsersIndex;
async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
fn main() {
axum::Router::<axum::body::Body>::new()
axum::Router::<(), axum::body::Body>::new()
.typed_get(option_handler)
.typed_post(result_handler)
.typed_post(result_handler_unit_struct);

View file

@ -8,7 +8,7 @@ pub type Result<T> = std::result::Result<T, ()>;
struct MyPath(u32, u32);
fn main() {
axum::Router::<axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id");
assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2");

View file

@ -5,7 +5,7 @@ use axum_extra::routing::TypedPath;
struct MyPath;
fn main() {
axum::Router::<axum::body::Body>::new()
axum::Router::<(), axum::body::Body>::new()
.route("/", axum::routing::get(|_: MyPath| async {}));
assert_eq!(MyPath::PATH, "/users");

View file

@ -8,5 +8,5 @@ struct MyPath {
}
fn main() {
axum::Router::<axum::body::Body>::new().typed_get(|_: MyPath| async {});
axum::Router::<(), axum::body::Body>::new().typed_get(|_: MyPath| async {});
}

View file

@ -35,6 +35,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Added `debug_handler` which is an attribute macro that improves
type errors when applied to handler function. It is re-exported from
`axum-macros`
- **added:** Added new type safe `State` extractor. This can be used with
`Router::with_state` and gives compile errors for missing states, whereas
`Extension` would result in runtime errors ([#1155])
- **breaking:** The following types or traits have a new `S` type param
(`()` by default) which represents the state ([#1155]):
- `FromRequest`
- `RequestParts`
- `Router`
- `MethodRouter`
- `Handler`
- **breaking:** `Router::route` now only accepts `MethodRouter`s created with
`get`, `post`, etc ([#1155])
- **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155])
- **added:** Support any middleware response that implements `IntoResponse` ([#1152])
- **breaking:** Require middleware added with `Handler::layer` to have
`Infallible` as the error type ([#1152])
@ -54,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1130]: https://github.com/tokio-rs/axum/pull/1130
[#1135]: https://github.com/tokio-rs/axum/pull/1135
[#1152]: https://github.com/tokio-rs/axum/pull/1152
[#1155]: https://github.com/tokio-rs/axum/pull/1155
[#1171]: https://github.com/tokio-rs/axum/pull/1171
[#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1248]: https://github.com/tokio-rs/axum/pull/1248

View file

@ -69,7 +69,7 @@ let some_fallible_service = tower::service_fn(|_req| async {
Ok::<_, anyhow::Error>(Response::new(Body::empty()))
});
let app = Router::new().route(
let app = Router::new().route_service(
"/",
// we cannot route to `some_fallible_service` directly since it might fail.
// we have to use `handle_error` which converts its errors into responses

View file

@ -421,13 +421,14 @@ use http::{StatusCode, header::{HeaderValue, USER_AGENT}};
struct ExtractUserAgent(HeaderValue);
#[async_trait]
impl<B> FromRequest<B> for ExtractUserAgent
impl<S, B> FromRequest<S, B> for ExtractUserAgent
where
B: Send,
S: Send,
{
type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = req.headers().get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone()))
} else {
@ -472,13 +473,14 @@ struct AuthenticatedUser {
}
#[async_trait]
impl<B> FromRequest<B> for AuthenticatedUser
impl<S, B> FromRequest<S, B> for AuthenticatedUser
where
B: Send,
S: Send,
{
type Rejection = Response;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(token)) =
TypedHeader::<Authorization<Bearer>>::from_request(req)
.await
@ -633,7 +635,7 @@ fn token_is_valid(token: &str) -> bool {
}
let app = Router::new().layer(middleware::from_fn(auth_middleware));
# let _: Router = app;
# let _: Router<()> = app;
```
[`body::Body`]: crate::body::Body

View file

@ -11,7 +11,7 @@ use axum::{
http::{StatusCode, Method, Uri},
};
let handler = get(|| async {}).fallback(fallback.into_service());
let handler = get(|| async {}).fallback(fallback);
let app = Router::new().route("/", handler);
@ -36,11 +36,9 @@ use axum::{
http::{StatusCode, Uri},
};
let one = get(|| async {})
.fallback(fallback_one.into_service());
let one = get(|| async {}).fallback(fallback_one);
let two = post(|| async {})
.fallback(fallback_two.into_service());
let two = post(|| async {}).fallback(fallback_two);
let method_route = one.merge(two);

View file

@ -6,7 +6,8 @@
- [Ordering](#ordering)
- [Writing middleware](#writing-middleware)
- [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure)
- [Sharing state between handlers and middleware](#sharing-state-between-handlers-and-middleware)
- [Accessing state in middleware](#accessing-state-in-middleware)
- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers)
# Intro
@ -95,7 +96,7 @@ let app = Router::new()
.layer(layer_one)
.layer(layer_two)
.layer(layer_three);
# let app: Router<axum::body::Body> = app;
# let _: Router<(), axum::body::Body> = app;
```
Think of the middleware as being layered like an onion where each new layer
@ -154,7 +155,7 @@ let app = Router::new()
.layer(layer_two)
.layer(layer_three),
);
# let app: Router<axum::body::Body> = app;
# let _: Router<(), axum::body::Body> = app;
```
`ServiceBuilder` works by composing all layers into one such that they run top
@ -386,9 +387,119 @@ Also note that handlers created from async functions don't care about
backpressure and are always ready. So if you're not using any Tower
middleware you don't have to worry about any of this.
# Sharing state between handlers and middleware
# Accessing state in middleware
State can be shared between middleware and handlers using [request extensions]:
Handlers can access state using the [`State`] extractor but this isn't available
to middleware. Instead you have to pass the state directly to middleware using
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
fields (if you're implementing a [`tower::Layer`])
## Accessing state in `axum::middleware::from_fn`
```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};
#[derive(Clone)]
struct AppState {}
async fn my_middleware<B>(
state: AppState,
req: Request<B>,
next: Next<B>,
) -> Response {
next.run(req).await
}
async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(middleware::from_fn(move |req, next| {
my_middleware(state.clone(), req, next)
}));
# let _: Router<_> = app;
```
## Accessing state in custom `tower::Layer`s
```rust
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};
use tower::{Layer, Service};
use std::task::{Context, Poll};
#[derive(Clone)]
struct AppState {}
#[derive(Clone)]
struct MyLayer {
state: AppState,
}
impl<S> Layer<S> for MyLayer {
type Service = MyService<S>;
fn layer(&self, inner: S) -> Self::Service {
MyService {
inner,
state: self.state.clone(),
}
}
}
#[derive(Clone)]
struct MyService<S> {
inner: S,
state: AppState,
}
impl<S, B> Service<Request<B>> for MyService<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
// do something with `self.state`
self.inner.call(req)
}
}
async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(MyLayer { state });
# let _: Router<_> = app;
```
# Passing state from middleware to handlers
State can be passed from middleware to handlers using [request extensions]:
```rust
use axum::{
@ -415,6 +526,8 @@ async fn auth<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusC
};
if let Some(current_user) = authorize_current_user(auth_header).await {
// insert the current user into a request extension so the handler can
// extract it
req.extensions_mut().insert(current_user);
Ok(next.run(req).await)
} else {
@ -437,7 +550,7 @@ async fn handler(
let app = Router::new()
.route("/", get(handler))
.route_layer(middleware::from_fn(auth));
# let app: Router = app;
# let _: Router<()> = app;
```
[Response extensions] can also be used but note that request extensions are not
@ -462,3 +575,4 @@ extensions you need.
[`MethodRouter::route_layer`]: crate::routing::MethodRouter::route_layer
[request extensions]: https://docs.rs/http/latest/http/request/struct.Request.html#method.extensions
[Response extensions]: https://docs.rs/http/latest/http/response/struct.Response.html#method.extensions
[`State`]: crate::extract::State

View file

@ -1,4 +1,4 @@
Add a fallback service to the router.
Add a fallback [`Handler`] to the router.
This service will be called if no routes matches the incoming request.
@ -13,7 +13,7 @@ use axum::{
let app = Router::new()
.route("/foo", get(|| async { /* ... */ }))
.fallback(fallback.into_service());
.fallback(fallback);
async fn fallback(uri: Uri) -> (StatusCode, String) {
(StatusCode::NOT_FOUND, format!("No route for {}", uri))

View file

@ -104,7 +104,7 @@ let api_routes = Router::new().nest("/users", get(|| async {}));
let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback.into_service());
.fallback(fallback);
# let _: Router = app;
```
@ -132,11 +132,11 @@ async fn api_fallback() -> (StatusCode, Json<Value>) {
let api_routes = Router::new()
.nest("/users", get(|| async {}))
// add dedicated fallback for requests starting with `/api`
.fallback(api_fallback.into_service());
.fallback(api_fallback);
let app = Router::new()
.nest("/api", api_routes)
.fallback(fallback.into_service());
.fallback(fallback);
# let _: Router = app;
```

View file

@ -3,10 +3,10 @@ Add another route to the router.
`path` is a string of path segments separated by `/`. Each segment
can be either static, a capture, or a wildcard.
`service` is the [`Service`] that should receive the request if the path matches
`path`. `service` will commonly be a handler wrapped in a method router like
[`get`](crate::routing::get). See [`handler`](crate::handler) for more details
on handlers.
`method_router` is the [`MethodRouter`] that should receive the request if the
path matches `path`. `method_router` will commonly be a handler wrapped in a method
router like [`get`](crate::routing::get). See [`handler`](crate::handler) for
more details on handlers.
# Static paths
@ -105,69 +105,6 @@ async fn serve_asset(Path(path): Path<String>) {}
# };
```
# Routing to any [`Service`]
axum also supports routing to general [`Service`]s:
```rust,no_run
use axum::{
Router,
body::Body,
routing::{any_service, get_service},
http::{Request, StatusCode},
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;
let app = Router::new()
.route(
// Any request to `/` goes to a service
"/",
// Services whose response body is not `axum::body::BoxBody`
// can be wrapped in `axum::routing::any_service` (or one of the other routing filters)
// to have the response body mapped
any_service(service_fn(|_: Request<Body>| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
)
.route(
"/foo",
// This service's response body is `axum::body::BoxBody` so
// it can be routed to directly.
service_fn(|req: Request<Body>| async move {
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let body = axum::body::boxed(body);
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
})
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Routing to arbitrary services in this way has complications for backpressure
([`Service::poll_ready`]). See the [Routing to services and backpressure] module
for more details.
[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure
# Panics
Panics if the route overlaps with another route:
@ -187,21 +124,3 @@ The static route `/foo` and the dynamic route `/:key` are not considered to
overlap and `/foo` will take precedence.
Also panics if `path` is empty.
## Nesting
`route` cannot be used to nest `Router`s. Instead use [`Router::nest`].
Attempting to will result in a panic:
```rust,should_panic
use axum::{routing::get, Router};
let app = Router::new().route(
"/",
Router::new().route("/foo", get(|| async {})),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```

View file

@ -0,0 +1,81 @@
Add another route to the router that calls a [`Service`].
# Example
```rust,no_run
use axum::{
Router,
body::Body,
routing::{any_service, get_service},
http::{Request, StatusCode},
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;
let app = Router::new()
.route(
// Any request to `/` goes to a service
"/",
// Services whose response body is not `axum::body::BoxBody`
// can be wrapped in `axum::routing::any_service` (or one of the other routing filters)
// to have the response body mapped
any_service(service_fn(|_: Request<Body>| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
)
.route_service(
"/foo",
// This service's response body is `axum::body::BoxBody` so
// it can be routed to directly.
service_fn(|req: Request<Body>| async move {
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let body = axum::body::boxed(body);
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
})
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Routing to arbitrary services in this way has complications for backpressure
([`Service::poll_ready`]). See the [Routing to services and backpressure] module
for more details.
# Panics
Panics for the same reasons as [`Router::route`] or if you attempt to route to a
`Router`:
```rust,should_panic
use axum::{routing::get, Router};
let app = Router::new().route_service(
"/",
Router::new().route("/foo", get(|| async {})),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
Use [`Router::nest`] instead.
[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure

View file

@ -1,7 +1,6 @@
#![doc = include_str!("../docs/error_handling.md")]
use crate::{
body::boxed,
extract::{FromRequest, RequestParts},
http::{Request, StatusCode},
response::{IntoResponse, Response},
@ -113,16 +112,16 @@ where
}
}
impl<S, F, ReqBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()>
impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
S: Service<Request<ReqBody>> + Clone + Send + 'static,
S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send,
S::Error: Send,
S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
ReqBody: Send + 'static,
B: Send + 'static,
{
type Response = Response;
type Error = Infallible;
@ -132,7 +131,7 @@ where
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request<B>) -> Self::Future {
let f = self.f.clone();
let clone = self.inner.clone();
@ -152,18 +151,18 @@ where
#[allow(unused_macros)]
macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => {
impl<S, F, ReqBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>>
impl<S, F, B, Res, Fut, $($ty,)*> Service<Request<B>>
for HandleError<S, F, ($($ty,)*)>
where
S: Service<Request<ReqBody>> + Clone + Send + 'static,
S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send,
S::Error: Send,
S::Future: Send,
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
$( $ty: FromRequest<ReqBody> + Send,)*
ReqBody: Send + 'static,
$( $ty: FromRequest<(), B> + Send,)*
B: Send + 'static,
{
type Response = Response;
type Error = Infallible;
@ -175,7 +174,7 @@ macro_rules! impl_service {
}
#[allow(non_snake_case)]
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request<B>) -> Self::Future {
let f = self.f.clone();
let clone = self.inner.clone();
@ -187,7 +186,7 @@ macro_rules! impl_service {
$(
let $ty = match $ty::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response().map(boxed)),
Err(rejection) => return Ok(rejection.into_response()),
};
)*
@ -200,7 +199,7 @@ macro_rules! impl_service {
match inner.oneshot(req).await {
Ok(res) => Ok(res.into_response()),
Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)),
Err(err) => Ok(f($($ty),*, err).await.into_response()),
}
});

View file

@ -73,14 +73,15 @@ use tower_service::Service;
pub struct Extension<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Extension<T>
impl<T, S, B> FromRequest<S, B> for Extension<T>
where
T: Clone + Send + Sync + 'static,
B: Send,
S: Send,
{
type Rejection = ExtensionRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let value = req
.extensions()
.get::<T>()

View file

@ -128,14 +128,15 @@ opaque_future! {
pub struct ConnectInfo<T>(pub T);
#[async_trait]
impl<B, T> FromRequest<B> for ConnectInfo<T>
impl<S, B, T> FromRequest<S, B> for ConnectInfo<T>
where
B: Send,
S: Send,
T: Clone + Send + Sync + 'static,
{
type Rejection = <Extension<Self> as FromRequest<B>>::Rejection;
type Rejection = <Extension<Self> as FromRequest<S, B>>::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = Extension::<Self>::from_request(req).await?;
Ok(connect_info)
}

View file

@ -36,15 +36,16 @@ use std::ops::Deref;
pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait]
impl<T, B, const N: u64> FromRequest<B> for ContentLengthLimit<T, N>
impl<T, S, B, const N: u64> FromRequest<S, B> for ContentLengthLimit<T, N>
where
T: FromRequest<B>,
T: FromRequest<S, B>,
T::Rejection: IntoResponse,
B: Send,
S: Send,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
.get(http::header::CONTENT_LENGTH)

View file

@ -21,13 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
pub struct Host(pub String);
#[async_trait]
impl<B> FromRequest<B> for Host
impl<S, B> FromRequest<S, B> for Host
where
B: Send,
S: Send,
{
type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(host) = parse_forwarded(req.headers()) {
return Ok(Host(host.to_owned()));
}

View file

@ -64,13 +64,14 @@ impl MatchedPath {
}
#[async_trait]
impl<B> FromRequest<B> for MatchedPath
impl<S, B> FromRequest<S, B> for MatchedPath
where
B: Send,
S: Send,
{
type Rejection = MatchedPathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let matched_path = req
.extensions()
.get::<Self>()
@ -84,7 +85,9 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::{extract::Extension, handler::Handler, routing::get, test_helpers::*, Router};
use crate::{
extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router,
};
use http::{Request, StatusCode};
use std::task::{Context, Poll};
use tower::layer::layer_fn;
@ -93,7 +96,7 @@ mod tests {
#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);
impl<B, S> Service<Request<B>> for SetMatchedPathExtension<S>
impl<S, B> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{

View file

@ -14,9 +14,10 @@ mod content_length_limit;
mod host;
mod raw_query;
mod request_parts;
mod state;
#[doc(inline)]
pub use axum_core::extract::{FromRequest, RequestParts};
pub use axum_core::extract::{FromRef, FromRequest, RequestParts};
#[doc(inline)]
#[allow(deprecated)]
@ -27,6 +28,7 @@ pub use self::{
path::Path,
raw_query::RawQuery,
request_parts::{BodyStream, RawBody},
state::State,
};
#[doc(no_inline)]
@ -73,13 +75,13 @@ pub use self::ws::WebSocketUpgrade;
#[doc(no_inline)]
pub use crate::TypedHeader;
pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or_else(BodyAlreadyExtracted::default)
}
// this is duplicated in `axum-extra/src/extract/form.rs`
pub(super) fn has_content_type<B>(
req: &RequestParts<B>,
pub(super) fn has_content_type<S, B>(
req: &RequestParts<S, B>,
expected_content_type: &mime::Mime,
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {

View file

@ -50,14 +50,15 @@ pub struct Multipart {
}
#[async_trait]
impl<B> FromRequest<B> for Multipart
impl<S, B> FromRequest<S, B> for Multipart
where
B: HttpBody<Data = Bytes> + Default + Unpin + Send + 'static,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = MultipartRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
@ -179,7 +180,7 @@ impl<'a> Field<'a> {
/// }
///
/// let app = Router::new().route("/upload", post(upload));
/// # let _: Router<axum::body::Body> = app;
/// # let _: Router = app;
/// ```
pub async fn chunk(&mut self) -> Result<Option<Bytes>, MultipartError> {
self.inner

View file

@ -163,14 +163,15 @@ impl<T> DerefMut for Path<T> {
}
#[async_trait]
impl<T, B> FromRequest<B> for Path<T>
impl<T, S, B> FromRequest<S, B> for Path<T>
where
T: DeserializeOwned + Send,
B: Send,
S: Send,
{
type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let params = match req.extensions_mut().get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {

View file

@ -49,14 +49,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Query<T>
impl<T, S, B> FromRequest<S, B> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send,
{
type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new)?;
@ -81,7 +82,8 @@ mod tests {
use std::fmt::Debug;
async fn check<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new(Request::builder().uri(uri.as_ref()).body(()).unwrap());
let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Query::<T>::from_request(&mut req).await.unwrap().0, value);
}

View file

@ -27,13 +27,14 @@ use std::convert::Infallible;
pub struct RawQuery(pub Option<String>);
#[async_trait]
impl<B> FromRequest<B> for RawQuery
impl<S, B> FromRequest<S, B> for RawQuery
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().map(|query| query.to_owned());
Ok(Self(query))
}

View file

@ -86,13 +86,14 @@ pub struct OriginalUri(pub Uri);
#[cfg(feature = "original-uri")]
#[async_trait]
impl<B> FromRequest<B> for OriginalUri
impl<S, B> FromRequest<S, B> for OriginalUri
where
B: Send,
S: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let uri = Extension::<Self>::from_request(req)
.await
.unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone())))
@ -140,15 +141,16 @@ impl Stream for BodyStream {
}
#[async_trait]
impl<B> FromRequest<B> for BodyStream
impl<S, B> FromRequest<S, B> for BodyStream
where
B: HttpBody + Send + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?
.map_data(Into::into)
.map_err(|err| Error::new(err.into()));
@ -196,13 +198,14 @@ fn body_stream_traits() {
pub struct RawBody<B = Body>(pub B);
#[async_trait]
impl<B> FromRequest<B> for RawBody<B>
impl<S, B> FromRequest<S, B> for RawBody<B>
where
B: Send,
S: Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;
Ok(Self(body))
}

207
axum/src/extract/state.rs Normal file
View file

@ -0,0 +1,207 @@
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequest, RequestParts};
use std::{
convert::Infallible,
ops::{Deref, DerefMut},
};
/// Extractor for state.
///
/// Note this extractor is not available to middleware. See ["Accessing state in
/// middleware"][state-from-middleware] for how to access state in middleware.
///
/// [state-from-middleware]: ../middleware/index.html#accessing-state-in-middleware
///
/// # With `Router`
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// // the application state
/// //
/// // here you can put configuration, database connection pools, or whatever
/// // state you need
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::with_state(state).route("/", get(handler));
///
/// async fn handler(
/// // access the state via the `State` extractor
/// // extracting a state of the wrong type results in a compile error
/// State(state): State<AppState>,
/// ) {
/// // use `state`...
/// }
/// # let _: Router<AppState> = app;
/// ```
///
/// # With `MethodRouter`
///
/// ```
/// use axum::{routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// let method_router_with_state = get(handler)
/// // provide the state so the handler can access it
/// .with_state(state);
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # With `Handler`
///
/// ```
/// use axum::{routing::get, handler::Handler, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// async fn handler(State(state): State<AppState>) {
/// // use `state`...
/// }
///
/// // provide the state so the handler can access it
/// let handler_with_state = handler.with_state(state);
///
/// # async {
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
/// .serve(handler_with_state.into_make_service())
/// .await
/// .expect("server failed");
/// # };
/// ```
///
/// # Substates
///
/// [`State`] only allows a single state type but you can use [`From`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
/// struct AppState {
/// // that holds some api specific state
/// api_state: ApiState,
/// }
///
/// // the api specific state
/// #[derive(Clone)]
/// struct ApiState {}
///
/// // support converting an `AppState` in an `ApiState`
/// impl FromRef<AppState> for ApiState {
/// fn from_ref(app_state: &AppState) -> ApiState {
/// app_state.api_state.clone()
/// }
/// }
///
/// let state = AppState {
/// api_state: ApiState {},
/// };
///
/// let app = Router::with_state(state)
/// .route("/", get(handler))
/// .route("/api/users", get(api_users));
///
/// async fn api_users(
/// // access the api specific state
/// State(api_state): State<ApiState>,
/// ) {
/// }
///
/// async fn handler(
/// // we can still access to top level state
/// State(state): State<AppState>,
/// ) {
/// }
/// # let _: Router<AppState> = app;
/// ```
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequest, RequestParts, FromRef};
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
/// // the extractor your library provides
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S, B> FromRequest<S, B> for MyLibraryExtractor
/// where
/// B: Send,
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // get a `MyLibraryState` from a reference to the state
/// let state = MyLibraryState::from_ref(req.state());
///
/// // ...
/// # todo!()
/// }
/// }
///
/// // the state your library needs
/// struct MyLibraryState {
/// // ...
/// }
/// ```
///
/// Note that you don't need to use the `State` extractor since you can access the state directly
/// from [`RequestParts`].
#[derive(Debug, Default, Clone, Copy)]
pub struct State<S>(pub S);
#[async_trait]
impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
where
B: Send,
InnerState: FromRef<OuterState>,
OuterState: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<OuterState, B>) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(req.state());
Ok(Self(inner_state))
}
}
impl<S> Deref for State<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S> DerefMut for State<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -275,13 +275,14 @@ impl WebSocketUpgrade {
}
#[async_trait]
impl<B> FromRequest<B> for WebSocketUpgrade
impl<S, B> FromRequest<S, B> for WebSocketUpgrade
where
B: Send,
S: Send,
{
type Rejection = WebSocketUpgradeRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() != Method::GET {
return Err(MethodNotGet.into());
}
@ -320,7 +321,7 @@ where
}
}
fn header_eq<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool {
fn header_eq<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
@ -328,7 +329,7 @@ fn header_eq<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) ->
}
}
fn header_contains<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool {
fn header_contains<S, B>(req: &RequestParts<S, B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) {
header
} else {

View file

@ -56,16 +56,17 @@ use std::ops::Deref;
pub struct Form<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Form<T>
impl<T, S, B> FromRequest<S, B> for Form<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_urlencoded::from_str(query)
@ -125,18 +126,16 @@ mod tests {
}
async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
let mut req = RequestParts::new(
Request::builder()
let req = Request::builder()
.uri(uri.as_ref())
.body(Empty::<Bytes>::new())
.unwrap(),
);
.unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
}
async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
let mut req = RequestParts::new(
Request::builder()
let req = Request::builder()
.uri("http://example.com/test")
.method(Method::POST)
.header(
@ -146,8 +145,8 @@ mod tests {
.body(Full::<Bytes>::new(
serde_urlencoded::to_string(&value).unwrap().into(),
))
.unwrap(),
);
.unwrap();
let mut req = RequestParts::new(req);
assert_eq!(Form::<T>::from_request(&mut req).await.unwrap().0, value);
}
@ -204,8 +203,7 @@ mod tests {
#[tokio::test]
async fn test_incorrect_content_type() {
let mut req = RequestParts::new(
Request::builder()
let req = Request::builder()
.uri("http://example.com/test")
.method(Method::POST)
.header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
@ -217,8 +215,8 @@ mod tests {
.unwrap()
.into(),
))
.unwrap(),
);
.unwrap();
let mut req = RequestParts::new(req);
assert!(matches!(
Form::<Pagination>::from_request(&mut req)
.await

View file

@ -19,29 +19,29 @@ opaque_future! {
pin_project! {
/// The response future for [`Layered`](super::Layered).
pub struct LayeredFuture<S, ReqBody>
pub struct LayeredFuture<B, S>
where
S: Service<Request<ReqBody>>,
S: Service<Request<B>>,
{
#[pin]
inner: Map<Oneshot<S, Request<ReqBody>>, fn(Result<S::Response, S::Error>) -> Response>,
inner: Map<Oneshot<S, Request<B>>, fn(Result<S::Response, S::Error>) -> Response>,
}
}
impl<S, ReqBody> LayeredFuture<S, ReqBody>
impl<B, S> LayeredFuture<B, S>
where
S: Service<Request<ReqBody>>,
S: Service<Request<B>>,
{
pub(super) fn new(
inner: Map<Oneshot<S, Request<ReqBody>>, fn(Result<S::Response, S::Error>) -> Response>,
inner: Map<Oneshot<S, Request<B>>, fn(Result<S::Response, S::Error>) -> Response>,
) -> Self {
Self { inner }
}
}
impl<S, ReqBody> Future for LayeredFuture<S, ReqBody>
impl<B, S> Future for LayeredFuture<B, S>
where
S: Service<Request<ReqBody>>,
S: Service<Request<B>>,
{
type Output = Response;

View file

@ -11,29 +11,40 @@ use tower_service::Service;
/// An adapter that makes a [`Handler`] into a [`Service`].
///
/// Created with [`Handler::into_service`].
pub struct IntoService<H, T, B> {
/// Created with [`HandlerWithoutStateExt::into_service`].
///
/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service
pub struct IntoService<H, T, S, B> {
handler: H,
state: S,
_marker: PhantomData<fn() -> (T, B)>,
}
impl<H, T, S, B> IntoService<H, T, S, B> {
/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoService<(), NotSendSync, NotSendSync>>();
assert_sync::<IntoService<(), NotSendSync, NotSendSync>>();
assert_send::<IntoService<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoService<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, B> IntoService<H, T, B> {
pub(super) fn new(handler: H) -> Self {
impl<H, T, S, B> IntoService<H, T, S, B> {
pub(super) fn new(handler: H, state: S) -> Self {
Self {
handler,
state,
_marker: PhantomData,
}
}
}
impl<H, T, B> fmt::Debug for IntoService<H, T, B> {
impl<H, T, S, B> fmt::Debug for IntoService<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("IntoService")
.field(&format_args!("..."))
@ -41,22 +52,25 @@ impl<H, T, B> fmt::Debug for IntoService<H, T, B> {
}
}
impl<H, T, B> Clone for IntoService<H, T, B>
impl<H, T, S, B> Clone for IntoService<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
state: self.state.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, B> Service<Request<B>> for IntoService<H, T, B>
impl<H, T, S, B> Service<Request<B>> for IntoService<H, T, S, B>
where
H: Handler<T, B> + Clone + Send + 'static,
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone,
{
type Response = Response;
type Error = Infallible;
@ -74,7 +88,7 @@ where
use futures_util::future::FutureExt;
let handler = self.handler.clone();
let future = Handler::call(handler, req);
let future = Handler::call(handler, self.state.clone(), req);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)

View file

@ -0,0 +1,85 @@
use super::Handler;
use crate::response::Response;
use http::Request;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower_service::Service;
pub(crate) struct IntoServiceStateInExtension<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, S, B> IntoServiceStateInExtension<H, T, S, B> {
pub(crate) fn new(handler: H) -> Self {
Self {
handler,
_marker: PhantomData,
}
}
}
impl<H, T, S, B> fmt::Debug for IntoServiceStateInExtension<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("IntoServiceStateInExtension")
.field(&format_args!("..."))
.finish()
}
}
impl<H, T, S, B> Clone for IntoServiceStateInExtension<H, T, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = super::future::IntoServiceFuture<H::Future>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or
// from `Layered` which bufferes in `<Layered as Handler>::call` and is therefore
// also always ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let state = req
.extensions_mut()
.remove::<S>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, state, req);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)
}
}

View file

@ -49,8 +49,11 @@ use tower_service::Service;
pub mod future;
mod into_service;
mod into_service_state_in_extension;
mod with_state;
pub use self::into_service::IntoService;
pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension;
pub use self::{into_service::IntoService, with_state::WithState};
/// Trait for async functions that can be used to handle requests.
///
@ -59,13 +62,45 @@ pub use self::into_service::IntoService;
///
/// See the [module docs](crate::handler) for more details.
///
/// # Converting `Handler`s into [`Service`]s
///
/// To convert `Handler`s into [`Service`]s you have to call either
/// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]:
///
/// ```
/// use tower::Service;
/// use axum::{
/// extract::State,
/// body::Body,
/// http::Request,
/// handler::{HandlerWithoutStateExt, Handler},
/// };
///
/// // this handler doesn't require any state
/// async fn one() {}
/// // so it can be converted to a service with `HandlerWithoutStateExt::into_service`
/// assert_service(one.into_service());
///
/// // this handler requires state
/// async fn two(_: State<String>) {}
/// // so we have to provide it
/// let handler_with_state = two.with_state(String::new());
/// // which gives us a `Service`
/// assert_service(handler_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service<S>(service: S)
/// where
/// S: Service<Request<Body>>,
/// {}
/// ```
#[doc = include_str!("../docs/debugging_handler_type_errors.md")]
pub trait Handler<T, B = Body>: Clone + Send + Sized + 'static {
pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
/// The type of future calling this handler returns.
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request.
fn call(self, req: Request<B>) -> Self::Future;
fn call(self, state: S, req: Request<B>) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler.
///
@ -103,112 +138,26 @@ pub trait Handler<T, B = Body>: Clone + Send + Sized + 'static {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
fn layer<L>(self, layer: L) -> Layered<L::Service, T>
fn layer<L>(self, layer: L) -> Layered<L, Self, T, S, B>
where
L: Layer<IntoService<Self, T, B>>,
L: Layer<WithState<Self, T, S, B>>,
{
Layered::new(layer.layer(self.into_service()))
}
/// Convert the handler into a [`Service`].
///
/// This is commonly used together with [`Router::fallback`]:
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// http::{Uri, Method, StatusCode},
/// response::IntoResponse,
/// routing::{get, Router},
/// };
/// use tower::make::Shared;
/// use std::net::SocketAddr;
///
/// async fn handler(method: Method, uri: Uri) -> (StatusCode, String) {
/// (StatusCode::NOT_FOUND, format!("Nothing to see at {} {}", method, uri))
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async {}))
/// .fallback(handler.into_service());
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`Router::fallback`]: crate::routing::Router::fallback
fn into_service(self) -> IntoService<Self, T, B> {
IntoService::new(self)
}
/// Convert the handler into a [`MakeService`].
///
/// This allows you to serve a single handler if you don't need any routing:
///
/// ```rust
/// use axum::{
/// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse,
/// };
/// use std::net::SocketAddr;
///
/// async fn handler(method: Method, uri: Uri, body: String) -> String {
/// format!("received `{} {}` with body `{:?}`", method, uri, body)
/// }
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(handler.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, B>> {
IntoMakeService::new(self.into_service())
}
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// response::IntoResponse,
/// extract::ConnectInfo,
/// };
/// use std::net::SocketAddr;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
/// format!("Hello {}", addr)
/// }
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(handler.into_make_service_with_connect_info::<SocketAddr>())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, B>, C> {
IntoMakeServiceWithConnectInfo::new(self.into_service())
Layered {
layer,
handler: self,
_marker: PhantomData,
}
}
impl<F, Fut, Res, B> Handler<(), B> for F
/// Convert the handler into a [`Service`] by providing the state
fn with_state(self, state: S) -> WithState<Self, T, S, B> {
WithState {
service: IntoService::new(self, state),
}
}
}
impl<F, Fut, Res, S, B> Handler<(), S, B> for F
where
F: FnOnce() -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
@ -217,7 +166,7 @@ where
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _req: Request<B>) -> Self::Future {
fn call(self, _state: S, _req: Request<B>) -> Self::Future {
Box::pin(async move { self().await.into_response() })
}
}
@ -225,19 +174,20 @@ where
macro_rules! impl_handler {
( $($ty:ident),* $(,)? ) => {
#[allow(non_snake_case)]
impl<F, Fut, B, Res, $($ty,)*> Handler<($($ty,)*), B> for F
impl<F, Fut, S, B, Res, $($ty,)*> Handler<($($ty,)*), S, B> for F
where
F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
B: Send + 'static,
S: Send + 'static,
Res: IntoResponse,
$( $ty: FromRequest<B> + Send,)*
$( $ty: FromRequest<S, B> + Send,)*
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<B>) -> Self::Future {
fn call(self, state: S, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::new(req);
let mut req = RequestParts::with_state(state, req);
$(
let $ty = match $ty::from_request(&mut req).await {
@ -260,44 +210,65 @@ all_the_tuples!(impl_handler);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///
/// Created with [`Handler::layer`]. See that method for more details.
pub struct Layered<S, T> {
svc: S,
_input: PhantomData<fn() -> T>,
pub struct Layered<L, H, T, S, B> {
layer: L,
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
impl<S, T> fmt::Debug for Layered<S, T>
impl<L, H, T, S, B> fmt::Debug for Layered<L, H, T, S, B>
where
S: fmt::Debug,
L: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Layered").field("svc", &self.svc).finish()
f.debug_struct("Layered")
.field("layer", &self.layer)
.finish()
}
}
impl<S, T> Clone for Layered<S, T>
impl<L, H, T, S, B> Clone for Layered<L, H, T, S, B>
where
S: Clone,
L: Clone,
H: Clone,
{
fn clone(&self) -> Self {
Self::new(self.svc.clone())
Self {
layer: self.layer.clone(),
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<S, T, ReqBody> Handler<T, ReqBody> for Layered<S, T>
impl<H, S, T, B, L> Handler<T, S, B> for Layered<L, H, T, S, B>
where
S: Service<Request<ReqBody>, Error = Infallible> + Clone + Send + 'static,
S::Response: IntoResponse,
S::Future: Send,
L: Layer<WithState<H, T, S, B>> + Clone + Send + 'static,
H: Handler<T, S, B>,
L::Service: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse,
<L::Service as Service<Request<B>>>::Future: Send,
T: 'static,
ReqBody: Send + 'static,
S: 'static,
B: Send + 'static,
{
type Future = future::LayeredFuture<S, ReqBody>;
type Future = future::LayeredFuture<B, L::Service>;
fn call(self, req: Request<ReqBody>) -> Self::Future {
fn call(self, state: S, req: Request<B>) -> Self::Future {
use futures_util::future::{FutureExt, Map};
let future: Map<_, fn(Result<S::Response, S::Error>) -> _> =
self.svc.oneshot(req).map(|result| match result {
let svc = self.handler.with_state(state);
let svc = self.layer.layer(svc);
let future: Map<
_,
fn(
Result<
<L::Service as Service<Request<B>>>::Response,
<L::Service as Service<Request<B>>>::Error,
>,
) -> _,
> = svc.oneshot(req).map(|result| match result {
Ok(res) => res.into_response(),
Err(err) => match err {},
});
@ -306,12 +277,49 @@ where
}
}
impl<S, T> Layered<S, T> {
pub(crate) fn new(svc: S) -> Self {
Self {
svc,
_input: PhantomData,
/// Extension trait for [`Handler`]s that don't have state.
///
/// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`].
///
/// [`MakeService`]: tower::make::MakeService
pub trait HandlerWithoutStateExt<T, B>: Handler<T, (), B> {
/// Convert the handler into a [`Service`] and no state.
fn into_service(self) -> WithState<Self, T, (), B>;
/// Convert the handler into a [`MakeService`] and no state.
///
/// See [`WithState::into_make_service`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, (), B>>;
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection and has no state.
///
/// See [`WithState::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, (), B>, C>;
}
impl<H, T, B> HandlerWithoutStateExt<T, B> for H
where
H: Handler<T, (), B>,
{
fn into_service(self) -> WithState<Self, T, (), B> {
self.with_state(())
}
fn into_make_service(self) -> IntoMakeService<IntoService<Self, T, (), B>> {
self.with_state(()).into_make_service()
}
fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<Self, T, (), B>, C> {
self.with_state(()).into_make_service_with_connect_info()
}
}

View file

@ -0,0 +1,144 @@
use super::{Handler, IntoService};
use crate::{extract::connect_info::IntoMakeServiceWithConnectInfo, routing::IntoMakeService};
use http::Request;
use std::task::{Context, Poll};
use tower_service::Service;
/// A [`Handler`] which has access to some state.
///
/// Implements [`Service`].
///
/// The state can be extracted with [`State`](crate::extract::State).
///
/// Created with [`Handler::with_state`].
pub struct WithState<H, T, S, B> {
pub(super) service: IntoService<H, T, S, B>,
}
impl<H, T, S, B> WithState<H, T, S, B> {
/// Get a reference to the state.
pub fn state(&self) -> &S {
self.service.state()
}
}
impl<H, T, S, B> WithState<H, T, S, B> {
/// Convert the handler into a [`MakeService`].
///
/// This allows you to serve a single handler if you don't need any routing:
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// extract::State,
/// http::{Uri, Method},
/// response::IntoResponse,
/// };
/// use std::net::SocketAddr;
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// async fn handler(State(state): State<AppState>) {
/// // ...
/// }
///
/// let app = handler.with_state(AppState {});
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<IntoService<H, T, S, B>> {
IntoMakeService::new(self.service)
}
/// Convert the handler into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// ```rust
/// use axum::{
/// Server,
/// handler::Handler,
/// response::IntoResponse,
/// extract::{ConnectInfo, State},
/// };
/// use std::net::SocketAddr;
///
/// #[derive(Clone)]
/// struct AppState {};
///
/// async fn handler(
/// ConnectInfo(addr): ConnectInfo<SocketAddr>,
/// State(state): State<AppState>,
/// ) -> String {
/// format!("Hello {}", addr)
/// }
///
/// let app = handler.with_state(AppState {});
///
/// # async {
/// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
/// .serve(app.into_make_service_with_connect_info::<SocketAddr>())
/// .await?;
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<IntoService<H, T, S, B>, C> {
IntoMakeServiceWithConnectInfo::new(self.service)
}
}
impl<H, T, S, B> Service<Request<B>> for WithState<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone,
{
type Response = <IntoService<H, T, S, B> as Service<Request<B>>>::Response;
type Error = <IntoService<H, T, S, B> as Service<Request<B>>>::Error;
type Future = <IntoService<H, T, S, B> as Service<Request<B>>>::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.service.call(req)
}
}
impl<H, T, S, B> std::fmt::Debug for WithState<H, T, S, B> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WithState")
.field("service", &self.service)
.finish()
}
}
impl<H, T, S, B> Clone for WithState<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
service: self.service.clone(),
}
}
}

View file

@ -94,16 +94,17 @@ use std::ops::{Deref, DerefMut};
pub struct Json<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Json<T>
impl<T, S, B> FromRequest<S, B> for Json<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = JsonRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if json_content_type(req) {
let bytes = Bytes::from_request(req).await?;
@ -136,7 +137,7 @@ where
}
}
fn json_content_type<B>(req: &RequestParts<B>) -> bool {
fn json_content_type<S, B>(req: &RequestParts<S, B>) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {

View file

@ -168,13 +168,48 @@
//! pool of database connections or clients to other services.
//!
//! The two most common ways of doing that are:
//! - Using the [`State`] extractor.
//! - Using request extensions
//! - Using closure captures
//!
//! ## Using the [`State`] extractor
//!
//! ```rust,no_run
//! use axum::{
//! extract::State,
//! routing::get,
//! Router,
//! };
//! use std::sync::Arc;
//!
//! struct AppState {
//! // ...
//! }
//!
//! let shared_state = Arc::new(AppState { /* ... */ });
//!
//! let app = Router::with_state(shared_state)
//! .route("/", get(handler));
//!
//! async fn handler(
//! State(state): State<Arc<AppState>>,
//! ) {
//! // ...
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! You should prefer using [`State`] if possible since it's more type safe. The downside is that
//! its less dynamic than request extensions.
//!
//! See [`State`] for more details about accessing state.
//!
//! ## Using request extensions
//!
//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension)
//! as layer and extractor:
//! Another way to extract state in handlers is using [`Extension`](crate::extract::Extension) as
//! layer and extractor:
//!
//! ```rust,no_run
//! use axum::{
@ -184,18 +219,18 @@
//! };
//! use std::sync::Arc;
//!
//! struct State {
//! struct AppState {
//! // ...
//! }
//!
//! let shared_state = Arc::new(State { /* ... */ });
//! let shared_state = Arc::new(AppState { /* ... */ });
//!
//! let app = Router::new()
//! .route("/", get(handler))
//! .layer(Extension(shared_state));
//!
//! async fn handler(
//! Extension(state): Extension<Arc<State>>,
//! Extension(state): Extension<Arc<AppState>>,
//! ) {
//! // ...
//! }
@ -223,11 +258,11 @@
//! use std::sync::Arc;
//! use serde::Deserialize;
//!
//! struct State {
//! struct AppState {
//! // ...
//! }
//!
//! let shared_state = Arc::new(State { /* ... */ });
//! let shared_state = Arc::new(AppState { /* ... */ });
//!
//! let app = Router::new()
//! .route(
@ -245,11 +280,11 @@
//! }),
//! );
//!
//! async fn get_user(Path(user_id): Path<String>, state: Arc<State>) {
//! async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
//! // ...
//! }
//!
//! async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<State>) {
//! async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
//! // ...
//! }
//!
@ -263,7 +298,7 @@
//! ```
//!
//! The downside to this approach is that it's a little more verbose than using
//! extensions.
//! [`State`] or extensions.
//!
//! # Building integrations for axum
//!
@ -350,6 +385,7 @@
//! [`Infallible`]: std::convert::Infallible
//! [load shed]: tower::load_shed
//! [`axum-core`]: http://crates.io/crates/axum-core
//! [`State`]: crate::extract::State
#![warn(
clippy::all,

View file

@ -45,13 +45,14 @@ use tower_service::Service;
/// struct RequireAuth;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for RequireAuth
/// impl<S, B> FromRequest<S, B> for RequireAuth
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = StatusCode;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .get(header::AUTHORIZATION)
@ -166,23 +167,23 @@ where
}
}
impl<S, E, ReqBody> Service<Request<ReqBody>> for FromExtractor<S, E>
impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
where
E: FromRequest<ReqBody> + 'static,
ReqBody: Default + Send + 'static,
S: Service<Request<ReqBody>> + Clone,
E: FromRequest<(), B> + 'static,
B: Default + Send + 'static,
S: Service<Request<B>> + Clone,
S::Response: IntoResponse,
{
type Response = Response;
type Error = S::Error;
type Future = ResponseFuture<ReqBody, S, E>;
type Future = ResponseFuture<B, S, E>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&mut self, req: Request<B>) -> Self::Future {
let extract_future = Box::pin(async move {
let mut req = RequestParts::new(req);
let extracted = E::from_request(&mut req).await;
@ -201,35 +202,37 @@ where
pin_project! {
/// Response future for [`FromExtractor`].
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<ReqBody, S, E>
pub struct ResponseFuture<B, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
E: FromRequest<(), B>,
S: Service<Request<B>>,
{
#[pin]
state: State<ReqBody, S, E>,
state: State<B, S, E>,
svc: Option<S>,
}
}
pin_project! {
#[project = StateProj]
enum State<ReqBody, S, E>
enum State<B, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
E: FromRequest<(), B>,
S: Service<Request<B>>,
{
Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> },
Extracting {
future: BoxFuture<'static, (RequestParts<(), B>, Result<E, E::Rejection>)>,
},
Call { #[pin] future: S::Future },
}
}
impl<ReqBody, S, E> Future for ResponseFuture<ReqBody, S, E>
impl<B, S, E> Future for ResponseFuture<B, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
E: FromRequest<(), B>,
S: Service<Request<B>>,
S::Response: IntoResponse,
ReqBody: Default,
B: Default,
{
type Output = Result<Response, S::Error>;
@ -277,13 +280,14 @@ mod tests {
struct RequireAuth;
#[async_trait::async_trait]
impl<B> FromRequest<B> for RequireAuth
impl<S, B> FromRequest<S, B> for RequireAuth
where
B: Send,
S: Send,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.get(header::AUTHORIZATION)

Some files were not shown because too many files have changed in this diff Show more