Add middleware::{from_fn_with_state, from_fn_with_state_arc} (#1342)

This commit is contained in:
David Pedersen 2022-08-31 20:28:54 +02:00 committed by GitHub
parent 3f92f7d254
commit 4c9edb4cd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 115 deletions

View file

@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
without any routes will now result in a panic. Previously, this just did without any routes will now result in a panic. Previously, this just did
nothing. [#1327] nothing. [#1327]
## Middleware
- **added**: Add `middleware::from_fn_with_state` and
`middleware::from_fn_with_state_arc` to enable running extractors that require
state ([#1342])
[#1327]: https://github.com/tokio-rs/axum/pull/1327 [#1327]: https://github.com/tokio-rs/axum/pull/1327
[#1342]: https://github.com/tokio-rs/axum/pull/1342
# 0.6.0-rc.1 (23. August, 2022) # 0.6.0-rc.1 (23. August, 2022)

View file

@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.
# Accessing state in middleware # Accessing state in middleware
Handlers can access state using the [`State`] extractor but this isn't available How to make state available to middleware depends on how the middleware is
to middleware. Instead you have to pass the state directly to middleware using written.
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
fields (if you're implementing a [`tower::Layer`])
## Accessing state in `axum::middleware::from_fn` ## Accessing state in `axum::middleware::from_fn`
```rust Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).
use axum::{
Router,
routing::get,
middleware::{self, Next},
response::Response,
extract::State,
http::Request,
};
#[derive(Clone)]
struct AppState {}
async fn my_middleware<B>(
state: AppState,
req: Request<B>,
next: Next<B>,
) -> Response {
next.run(req).await
}
async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
.route("/", get(handler))
.layer(middleware::from_fn(move |req, next| {
my_middleware(state.clone(), req, next)
}));
# let _: Router<_> = app;
```
## Accessing state in custom `tower::Layer`s ## Accessing state in custom `tower::Layer`s
@ -482,7 +449,10 @@ where
} }
fn call(&mut self, req: Request<B>) -> Self::Future { fn call(&mut self, req: Request<B>) -> Self::Future {
// do something with `self.state` // Do something with `self.state`.
//
// See `axum::RequestExt` for how to run extractors directly from
// a `Request`.
self.inner.call(req) self.inner.call(req)
} }

View file

@ -9,6 +9,7 @@ use std::{
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{util::BoxCloneService, ServiceBuilder}; use tower::{util::BoxCloneService, ServiceBuilder};
@ -90,9 +91,16 @@ use tower_service::Service;
/// # let app: Router = app; /// # let app: Router = app;
/// ``` /// ```
/// ///
/// # Passing state /// [extractors]: crate::extract::FromRequest
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
from_fn_with_state((), f)
}
/// Create a middleware from an async function with the given state.
/// ///
/// State can be passed to the function like so: /// See [`State`](crate::extract::State) for more details about accessing state.
///
/// # Example
/// ///
/// ```rust /// ```rust
/// use axum::{ /// use axum::{
@ -100,49 +108,15 @@ use tower_service::Service;
/// http::{Request, StatusCode}, /// http::{Request, StatusCode},
/// routing::get, /// routing::get,
/// response::{IntoResponse, Response}, /// response::{IntoResponse, Response},
/// middleware::{self, Next}
/// };
///
/// #[derive(Clone)]
/// struct State { /* ... */ }
///
/// async fn my_middleware<B>(
/// req: Request<B>,
/// next: Next<B>,
/// state: State,
/// ) -> Response {
/// // ...
/// # ().into_response()
/// }
///
/// let state = State { /* ... */ };
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn(move |req, next| {
/// my_middleware(req, next, state.clone())
/// }));
/// # let app: Router = app;
/// ```
///
/// Or via extensions:
///
/// ```rust
/// use axum::{
/// Router,
/// extract::Extension,
/// http::{Request, StatusCode},
/// routing::get,
/// response::{IntoResponse, Response},
/// middleware::{self, Next}, /// middleware::{self, Next},
/// extract::State,
/// }; /// };
/// use tower::ServiceBuilder;
/// ///
/// #[derive(Clone)] /// #[derive(Clone)]
/// struct State { /* ... */ } /// struct AppState { /* ... */ }
/// ///
/// async fn my_middleware<B>( /// async fn my_middleware<B>(
/// Extension(state): Extension<State>, /// State(state): State<AppState>,
/// req: Request<B>, /// req: Request<B>,
/// next: Next<B>, /// next: Next<B>,
/// ) -> Response { /// ) -> Response {
@ -150,22 +124,24 @@ use tower_service::Service;
/// # ().into_response() /// # ().into_response()
/// } /// }
/// ///
/// let state = State { /* ... */ }; /// let state = AppState { /* ... */ };
/// ///
/// let app = Router::new() /// let app = Router::with_state(state.clone())
/// .route("/", get(|| async { /* ... */ })) /// .route("/", get(|| async { /* ... */ }))
/// .layer( /// .route_layer(middleware::from_fn_with_state(state, my_middleware));
/// ServiceBuilder::new() /// # let app: Router<_> = app;
/// .layer(Extension(state))
/// .layer(middleware::from_fn(my_middleware)),
/// );
/// # let app: Router = app;
/// ``` /// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
from_fn_with_state_arc(Arc::new(state), f)
}
/// Create a middleware from an async function with the given [`Arc`]'ed state.
/// ///
/// [extractors]: crate::extract::FromRequest /// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> { pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer { FromFnLayer {
f, f,
state,
_extractor: PhantomData, _extractor: PhantomData,
} }
} }
@ -175,45 +151,50 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
/// ///
/// Created with [`from_fn`]. See that function for more details. /// Created with [`from_fn`]. See that function for more details.
pub struct FromFnLayer<F, T> { pub struct FromFnLayer<F, S, T> {
f: F, f: F,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>, _extractor: PhantomData<fn() -> T>,
} }
impl<F, T> Clone for FromFnLayer<F, T> impl<F, S, T> Clone for FromFnLayer<F, S, T>
where where
F: Clone, F: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
f: self.f.clone(), f: self.f.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor, _extractor: self._extractor,
} }
} }
} }
impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {} impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
impl<S, F, T> Layer<S> for FromFnLayer<F, T>
where where
F: Clone, F: Clone,
{ {
type Service = FromFn<F, S, T>; type Service = FromFn<F, S, I, T>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: I) -> Self::Service {
FromFn { FromFn {
f: self.f.clone(), f: self.f.clone(),
state: Arc::clone(&self.state),
inner, inner,
_extractor: PhantomData, _extractor: PhantomData,
} }
} }
} }
impl<F, T> fmt::Debug for FromFnLayer<F, T> { impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer") f.debug_struct("FromFnLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would // Write out the type name, without quoting it as `&type_name::<F>()` would
.field("f", &format_args!("{}", type_name::<F>())) .field("f", &format_args!("{}", type_name::<F>()))
.field("state", &self.state)
.finish() .finish()
} }
} }
@ -221,52 +202,48 @@ impl<F, T> fmt::Debug for FromFnLayer<F, T> {
/// A middleware created from an async function. /// A middleware created from an async function.
/// ///
/// Created with [`from_fn`]. See that function for more details. /// Created with [`from_fn`]. See that function for more details.
pub struct FromFn<F, S, T> { pub struct FromFn<F, S, I, T> {
f: F, f: F,
inner: S, inner: I,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>, _extractor: PhantomData<fn() -> T>,
} }
impl<F, S, T> Clone for FromFn<F, S, T> impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where where
F: Clone, F: Clone,
S: Clone, I: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
f: self.f.clone(), f: self.f.clone(),
inner: self.inner.clone(), inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor, _extractor: self._extractor,
} }
} }
} }
impl<F, S, T> Copy for FromFn<F, S, T>
where
F: Copy,
S: Copy,
{
}
macro_rules! impl_service { macro_rules! impl_service {
( (
[$($ty:ident),*], $last:ident [$($ty:ident),*], $last:ident
) => { ) => {
#[allow(non_snake_case, unused_mut)] #[allow(non_snake_case, unused_mut)]
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)> impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
where where
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static, F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<()> + Send, )* $( $ty: FromRequestParts<S> + Send, )*
$last: FromRequest<(), B> + Send, $last: FromRequest<S, B> + Send,
Fut: Future<Output = Out> + Send + 'static, Fut: Future<Output = Out> + Send + 'static,
Out: IntoResponse + 'static, Out: IntoResponse + 'static,
S: Service<Request<B>, Error = Infallible> I: Service<Request<B>, Error = Infallible>
+ Clone + Clone
+ Send + Send
+ 'static, + 'static,
S::Response: IntoResponse, I::Response: IntoResponse,
S::Future: Send + 'static, I::Future: Send + 'static,
B: Send + 'static, B: Send + 'static,
S: Send + Sync + 'static,
{ {
type Response = Response; type Response = Response;
type Error = Infallible; type Error = Infallible;
@ -281,12 +258,13 @@ macro_rules! impl_service {
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let mut f = self.f.clone(); let mut f = self.f.clone();
let state = Arc::clone(&self.state);
let future = Box::pin(async move { let future = Box::pin(async move {
let (mut parts, body) = req.into_parts(); let (mut parts, body) = req.into_parts();
$( $(
let $ty = match $ty::from_request_parts(&mut parts, &()).await { let $ty = match $ty::from_request_parts(&mut parts, &state).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return rejection.into_response(), Err(rejection) => return rejection.into_response(),
}; };
@ -294,7 +272,7 @@ macro_rules! impl_service {
let req = Request::from_parts(parts, body); let req = Request::from_parts(parts, body);
let $last = match $last::from_request(req, &()).await { let $last = match $last::from_request(req, &state).await {
Ok(value) => value, Ok(value) => value,
Err(rejection) => return rejection.into_response(), Err(rejection) => return rejection.into_response(),
}; };
@ -342,14 +320,16 @@ impl_service!(
T16 T16
); );
impl<F, S, T> fmt::Debug for FromFn<F, S, T> impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
where where
S: fmt::Debug, S: fmt::Debug,
I: fmt::Debug,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FromFnLayer") f.debug_struct("FromFnLayer")
.field("f", &format_args!("{}", type_name::<F>())) .field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner) .field("inner", &self.inner)
.field("state", &self.state)
.finish() .finish()
} }
} }

View file

@ -6,7 +6,9 @@ mod from_extractor;
mod from_fn; mod from_fn;
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer}; pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next}; pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
};
pub use crate::extension::AddExtension; pub use crate::extension::AddExtension;
pub mod future { pub mod future {