mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-21 07:50:49 +01:00
Add middleware::{from_fn_with_state, from_fn_with_state_arc}
(#1342)
This commit is contained in:
parent
3f92f7d254
commit
4c9edb4cd4
4 changed files with 74 additions and 115 deletions
|
@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
without any routes will now result in a panic. Previously, this just did
|
without any routes will now result in a panic. Previously, this just did
|
||||||
nothing. [#1327]
|
nothing. [#1327]
|
||||||
|
|
||||||
|
## Middleware
|
||||||
|
|
||||||
|
- **added**: Add `middleware::from_fn_with_state` and
|
||||||
|
`middleware::from_fn_with_state_arc` to enable running extractors that require
|
||||||
|
state ([#1342])
|
||||||
|
|
||||||
[#1327]: https://github.com/tokio-rs/axum/pull/1327
|
[#1327]: https://github.com/tokio-rs/axum/pull/1327
|
||||||
|
[#1342]: https://github.com/tokio-rs/axum/pull/1342
|
||||||
|
|
||||||
# 0.6.0-rc.1 (23. August, 2022)
|
# 0.6.0-rc.1 (23. August, 2022)
|
||||||
|
|
||||||
|
|
|
@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.
|
||||||
|
|
||||||
# Accessing state in middleware
|
# Accessing state in middleware
|
||||||
|
|
||||||
Handlers can access state using the [`State`] extractor but this isn't available
|
How to make state available to middleware depends on how the middleware is
|
||||||
to middleware. Instead you have to pass the state directly to middleware using
|
written.
|
||||||
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
|
|
||||||
fields (if you're implementing a [`tower::Layer`])
|
|
||||||
|
|
||||||
## Accessing state in `axum::middleware::from_fn`
|
## Accessing state in `axum::middleware::from_fn`
|
||||||
|
|
||||||
```rust
|
Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).
|
||||||
use axum::{
|
|
||||||
Router,
|
|
||||||
routing::get,
|
|
||||||
middleware::{self, Next},
|
|
||||||
response::Response,
|
|
||||||
extract::State,
|
|
||||||
http::Request,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct AppState {}
|
|
||||||
|
|
||||||
async fn my_middleware<B>(
|
|
||||||
state: AppState,
|
|
||||||
req: Request<B>,
|
|
||||||
next: Next<B>,
|
|
||||||
) -> Response {
|
|
||||||
next.run(req).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handler(_: State<AppState>) {}
|
|
||||||
|
|
||||||
let state = AppState {};
|
|
||||||
|
|
||||||
let app = Router::with_state(state.clone())
|
|
||||||
.route("/", get(handler))
|
|
||||||
.layer(middleware::from_fn(move |req, next| {
|
|
||||||
my_middleware(state.clone(), req, next)
|
|
||||||
}));
|
|
||||||
# let _: Router<_> = app;
|
|
||||||
```
|
|
||||||
|
|
||||||
## Accessing state in custom `tower::Layer`s
|
## Accessing state in custom `tower::Layer`s
|
||||||
|
|
||||||
|
@ -482,7 +449,10 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||||
// do something with `self.state`
|
// Do something with `self.state`.
|
||||||
|
//
|
||||||
|
// See `axum::RequestExt` for how to run extractors directly from
|
||||||
|
// a `Request`.
|
||||||
|
|
||||||
self.inner.call(req)
|
self.inner.call(req)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ use std::{
|
||||||
future::Future,
|
future::Future,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tower::{util::BoxCloneService, ServiceBuilder};
|
use tower::{util::BoxCloneService, ServiceBuilder};
|
||||||
|
@ -90,9 +91,16 @@ use tower_service::Service;
|
||||||
/// # let app: Router = app;
|
/// # let app: Router = app;
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// # Passing state
|
/// [extractors]: crate::extract::FromRequest
|
||||||
|
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
|
||||||
|
from_fn_with_state((), f)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a middleware from an async function with the given state.
|
||||||
///
|
///
|
||||||
/// State can be passed to the function like so:
|
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use axum::{
|
/// use axum::{
|
||||||
|
@ -100,49 +108,15 @@ use tower_service::Service;
|
||||||
/// http::{Request, StatusCode},
|
/// http::{Request, StatusCode},
|
||||||
/// routing::get,
|
/// routing::get,
|
||||||
/// response::{IntoResponse, Response},
|
/// response::{IntoResponse, Response},
|
||||||
/// middleware::{self, Next}
|
|
||||||
/// };
|
|
||||||
///
|
|
||||||
/// #[derive(Clone)]
|
|
||||||
/// struct State { /* ... */ }
|
|
||||||
///
|
|
||||||
/// async fn my_middleware<B>(
|
|
||||||
/// req: Request<B>,
|
|
||||||
/// next: Next<B>,
|
|
||||||
/// state: State,
|
|
||||||
/// ) -> Response {
|
|
||||||
/// // ...
|
|
||||||
/// # ().into_response()
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// let state = State { /* ... */ };
|
|
||||||
///
|
|
||||||
/// let app = Router::new()
|
|
||||||
/// .route("/", get(|| async { /* ... */ }))
|
|
||||||
/// .route_layer(middleware::from_fn(move |req, next| {
|
|
||||||
/// my_middleware(req, next, state.clone())
|
|
||||||
/// }));
|
|
||||||
/// # let app: Router = app;
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// Or via extensions:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// use axum::{
|
|
||||||
/// Router,
|
|
||||||
/// extract::Extension,
|
|
||||||
/// http::{Request, StatusCode},
|
|
||||||
/// routing::get,
|
|
||||||
/// response::{IntoResponse, Response},
|
|
||||||
/// middleware::{self, Next},
|
/// middleware::{self, Next},
|
||||||
|
/// extract::State,
|
||||||
/// };
|
/// };
|
||||||
/// use tower::ServiceBuilder;
|
|
||||||
///
|
///
|
||||||
/// #[derive(Clone)]
|
/// #[derive(Clone)]
|
||||||
/// struct State { /* ... */ }
|
/// struct AppState { /* ... */ }
|
||||||
///
|
///
|
||||||
/// async fn my_middleware<B>(
|
/// async fn my_middleware<B>(
|
||||||
/// Extension(state): Extension<State>,
|
/// State(state): State<AppState>,
|
||||||
/// req: Request<B>,
|
/// req: Request<B>,
|
||||||
/// next: Next<B>,
|
/// next: Next<B>,
|
||||||
/// ) -> Response {
|
/// ) -> Response {
|
||||||
|
@ -150,22 +124,24 @@ use tower_service::Service;
|
||||||
/// # ().into_response()
|
/// # ().into_response()
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
/// let state = State { /* ... */ };
|
/// let state = AppState { /* ... */ };
|
||||||
///
|
///
|
||||||
/// let app = Router::new()
|
/// let app = Router::with_state(state.clone())
|
||||||
/// .route("/", get(|| async { /* ... */ }))
|
/// .route("/", get(|| async { /* ... */ }))
|
||||||
/// .layer(
|
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
|
||||||
/// ServiceBuilder::new()
|
/// # let app: Router<_> = app;
|
||||||
/// .layer(Extension(state))
|
|
||||||
/// .layer(middleware::from_fn(my_middleware)),
|
|
||||||
/// );
|
|
||||||
/// # let app: Router = app;
|
|
||||||
/// ```
|
/// ```
|
||||||
|
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
|
||||||
|
from_fn_with_state_arc(Arc::new(state), f)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a middleware from an async function with the given [`Arc`]'ed state.
|
||||||
///
|
///
|
||||||
/// [extractors]: crate::extract::FromRequest
|
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||||
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
|
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
|
||||||
FromFnLayer {
|
FromFnLayer {
|
||||||
f,
|
f,
|
||||||
|
state,
|
||||||
_extractor: PhantomData,
|
_extractor: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,45 +151,50 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
|
||||||
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
|
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
|
||||||
///
|
///
|
||||||
/// Created with [`from_fn`]. See that function for more details.
|
/// Created with [`from_fn`]. See that function for more details.
|
||||||
pub struct FromFnLayer<F, T> {
|
pub struct FromFnLayer<F, S, T> {
|
||||||
f: F,
|
f: F,
|
||||||
|
state: Arc<S>,
|
||||||
_extractor: PhantomData<fn() -> T>,
|
_extractor: PhantomData<fn() -> T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, T> Clone for FromFnLayer<F, T>
|
impl<F, S, T> Clone for FromFnLayer<F, S, T>
|
||||||
where
|
where
|
||||||
F: Clone,
|
F: Clone,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
f: self.f.clone(),
|
f: self.f.clone(),
|
||||||
|
state: Arc::clone(&self.state),
|
||||||
_extractor: self._extractor,
|
_extractor: self._extractor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}
|
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
|
||||||
|
|
||||||
impl<S, F, T> Layer<S> for FromFnLayer<F, T>
|
|
||||||
where
|
where
|
||||||
F: Clone,
|
F: Clone,
|
||||||
{
|
{
|
||||||
type Service = FromFn<F, S, T>;
|
type Service = FromFn<F, S, I, T>;
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: I) -> Self::Service {
|
||||||
FromFn {
|
FromFn {
|
||||||
f: self.f.clone(),
|
f: self.f.clone(),
|
||||||
|
state: Arc::clone(&self.state),
|
||||||
inner,
|
inner,
|
||||||
_extractor: PhantomData,
|
_extractor: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, T> fmt::Debug for FromFnLayer<F, T> {
|
impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
|
||||||
|
where
|
||||||
|
S: fmt::Debug,
|
||||||
|
{
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("FromFnLayer")
|
f.debug_struct("FromFnLayer")
|
||||||
// Write out the type name, without quoting it as `&type_name::<F>()` would
|
// Write out the type name, without quoting it as `&type_name::<F>()` would
|
||||||
.field("f", &format_args!("{}", type_name::<F>()))
|
.field("f", &format_args!("{}", type_name::<F>()))
|
||||||
|
.field("state", &self.state)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -221,52 +202,48 @@ impl<F, T> fmt::Debug for FromFnLayer<F, T> {
|
||||||
/// A middleware created from an async function.
|
/// A middleware created from an async function.
|
||||||
///
|
///
|
||||||
/// Created with [`from_fn`]. See that function for more details.
|
/// Created with [`from_fn`]. See that function for more details.
|
||||||
pub struct FromFn<F, S, T> {
|
pub struct FromFn<F, S, I, T> {
|
||||||
f: F,
|
f: F,
|
||||||
inner: S,
|
inner: I,
|
||||||
|
state: Arc<S>,
|
||||||
_extractor: PhantomData<fn() -> T>,
|
_extractor: PhantomData<fn() -> T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, S, T> Clone for FromFn<F, S, T>
|
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
|
||||||
where
|
where
|
||||||
F: Clone,
|
F: Clone,
|
||||||
S: Clone,
|
I: Clone,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
f: self.f.clone(),
|
f: self.f.clone(),
|
||||||
inner: self.inner.clone(),
|
inner: self.inner.clone(),
|
||||||
|
state: Arc::clone(&self.state),
|
||||||
_extractor: self._extractor,
|
_extractor: self._extractor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F, S, T> Copy for FromFn<F, S, T>
|
|
||||||
where
|
|
||||||
F: Copy,
|
|
||||||
S: Copy,
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! impl_service {
|
macro_rules! impl_service {
|
||||||
(
|
(
|
||||||
[$($ty:ident),*], $last:ident
|
[$($ty:ident),*], $last:ident
|
||||||
) => {
|
) => {
|
||||||
#[allow(non_snake_case, unused_mut)]
|
#[allow(non_snake_case, unused_mut)]
|
||||||
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
|
impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
|
||||||
where
|
where
|
||||||
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
|
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
|
||||||
$( $ty: FromRequestParts<()> + Send, )*
|
$( $ty: FromRequestParts<S> + Send, )*
|
||||||
$last: FromRequest<(), B> + Send,
|
$last: FromRequest<S, B> + Send,
|
||||||
Fut: Future<Output = Out> + Send + 'static,
|
Fut: Future<Output = Out> + Send + 'static,
|
||||||
Out: IntoResponse + 'static,
|
Out: IntoResponse + 'static,
|
||||||
S: Service<Request<B>, Error = Infallible>
|
I: Service<Request<B>, Error = Infallible>
|
||||||
+ Clone
|
+ Clone
|
||||||
+ Send
|
+ Send
|
||||||
+ 'static,
|
+ 'static,
|
||||||
S::Response: IntoResponse,
|
I::Response: IntoResponse,
|
||||||
S::Future: Send + 'static,
|
I::Future: Send + 'static,
|
||||||
B: Send + 'static,
|
B: Send + 'static,
|
||||||
|
S: Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
type Response = Response;
|
type Response = Response;
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
|
@ -281,12 +258,13 @@ macro_rules! impl_service {
|
||||||
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
|
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
|
||||||
|
|
||||||
let mut f = self.f.clone();
|
let mut f = self.f.clone();
|
||||||
|
let state = Arc::clone(&self.state);
|
||||||
|
|
||||||
let future = Box::pin(async move {
|
let future = Box::pin(async move {
|
||||||
let (mut parts, body) = req.into_parts();
|
let (mut parts, body) = req.into_parts();
|
||||||
|
|
||||||
$(
|
$(
|
||||||
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
|
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
|
||||||
Ok(value) => value,
|
Ok(value) => value,
|
||||||
Err(rejection) => return rejection.into_response(),
|
Err(rejection) => return rejection.into_response(),
|
||||||
};
|
};
|
||||||
|
@ -294,7 +272,7 @@ macro_rules! impl_service {
|
||||||
|
|
||||||
let req = Request::from_parts(parts, body);
|
let req = Request::from_parts(parts, body);
|
||||||
|
|
||||||
let $last = match $last::from_request(req, &()).await {
|
let $last = match $last::from_request(req, &state).await {
|
||||||
Ok(value) => value,
|
Ok(value) => value,
|
||||||
Err(rejection) => return rejection.into_response(),
|
Err(rejection) => return rejection.into_response(),
|
||||||
};
|
};
|
||||||
|
@ -342,14 +320,16 @@ impl_service!(
|
||||||
T16
|
T16
|
||||||
);
|
);
|
||||||
|
|
||||||
impl<F, S, T> fmt::Debug for FromFn<F, S, T>
|
impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
|
||||||
where
|
where
|
||||||
S: fmt::Debug,
|
S: fmt::Debug,
|
||||||
|
I: fmt::Debug,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("FromFnLayer")
|
f.debug_struct("FromFnLayer")
|
||||||
.field("f", &format_args!("{}", type_name::<F>()))
|
.field("f", &format_args!("{}", type_name::<F>()))
|
||||||
.field("inner", &self.inner)
|
.field("inner", &self.inner)
|
||||||
|
.field("state", &self.state)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,9 @@ mod from_extractor;
|
||||||
mod from_fn;
|
mod from_fn;
|
||||||
|
|
||||||
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
|
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
|
||||||
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next};
|
pub use self::from_fn::{
|
||||||
|
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
|
||||||
|
};
|
||||||
pub use crate::extension::AddExtension;
|
pub use crate::extension::AddExtension;
|
||||||
|
|
||||||
pub mod future {
|
pub mod future {
|
||||||
|
|
Loading…
Reference in a new issue