mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Add middleware::{from_fn_with_state, from_fn_with_state_arc}
(#1342)
This commit is contained in:
parent
3f92f7d254
commit
4c9edb4cd4
4 changed files with 74 additions and 115 deletions
|
@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
without any routes will now result in a panic. Previously, this just did
|
||||
nothing. [#1327]
|
||||
|
||||
## Middleware
|
||||
|
||||
- **added**: Add `middleware::from_fn_with_state` and
|
||||
`middleware::from_fn_with_state_arc` to enable running extractors that require
|
||||
state ([#1342])
|
||||
|
||||
[#1327]: https://github.com/tokio-rs/axum/pull/1327
|
||||
[#1342]: https://github.com/tokio-rs/axum/pull/1342
|
||||
|
||||
# 0.6.0-rc.1 (23. August, 2022)
|
||||
|
||||
|
|
|
@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.
|
|||
|
||||
# Accessing state in middleware
|
||||
|
||||
Handlers can access state using the [`State`] extractor but this isn't available
|
||||
to middleware. Instead you have to pass the state directly to middleware using
|
||||
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
|
||||
fields (if you're implementing a [`tower::Layer`])
|
||||
How to make state available to middleware depends on how the middleware is
|
||||
written.
|
||||
|
||||
## Accessing state in `axum::middleware::from_fn`
|
||||
|
||||
```rust
|
||||
use axum::{
|
||||
Router,
|
||||
routing::get,
|
||||
middleware::{self, Next},
|
||||
response::Response,
|
||||
extract::State,
|
||||
http::Request,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {}
|
||||
|
||||
async fn my_middleware<B>(
|
||||
state: AppState,
|
||||
req: Request<B>,
|
||||
next: Next<B>,
|
||||
) -> Response {
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
async fn handler(_: State<AppState>) {}
|
||||
|
||||
let state = AppState {};
|
||||
|
||||
let app = Router::with_state(state.clone())
|
||||
.route("/", get(handler))
|
||||
.layer(middleware::from_fn(move |req, next| {
|
||||
my_middleware(state.clone(), req, next)
|
||||
}));
|
||||
# let _: Router<_> = app;
|
||||
```
|
||||
Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).
|
||||
|
||||
## Accessing state in custom `tower::Layer`s
|
||||
|
||||
|
@ -482,7 +449,10 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||
// do something with `self.state`
|
||||
// Do something with `self.state`.
|
||||
//
|
||||
// See `axum::RequestExt` for how to run extractors directly from
|
||||
// a `Request`.
|
||||
|
||||
self.inner.call(req)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ use std::{
|
|||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder};
|
||||
|
@ -90,9 +91,16 @@ use tower_service::Service;
|
|||
/// # let app: Router = app;
|
||||
/// ```
|
||||
///
|
||||
/// # Passing state
|
||||
/// [extractors]: crate::extract::FromRequest
|
||||
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
|
||||
from_fn_with_state((), f)
|
||||
}
|
||||
|
||||
/// Create a middleware from an async function with the given state.
|
||||
///
|
||||
/// State can be passed to the function like so:
|
||||
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{
|
||||
|
@ -100,49 +108,15 @@ use tower_service::Service;
|
|||
/// http::{Request, StatusCode},
|
||||
/// routing::get,
|
||||
/// response::{IntoResponse, Response},
|
||||
/// middleware::{self, Next}
|
||||
/// };
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct State { /* ... */ }
|
||||
///
|
||||
/// async fn my_middleware<B>(
|
||||
/// req: Request<B>,
|
||||
/// next: Next<B>,
|
||||
/// state: State,
|
||||
/// ) -> Response {
|
||||
/// // ...
|
||||
/// # ().into_response()
|
||||
/// }
|
||||
///
|
||||
/// let state = State { /* ... */ };
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/", get(|| async { /* ... */ }))
|
||||
/// .route_layer(middleware::from_fn(move |req, next| {
|
||||
/// my_middleware(req, next, state.clone())
|
||||
/// }));
|
||||
/// # let app: Router = app;
|
||||
/// ```
|
||||
///
|
||||
/// Or via extensions:
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{
|
||||
/// Router,
|
||||
/// extract::Extension,
|
||||
/// http::{Request, StatusCode},
|
||||
/// routing::get,
|
||||
/// response::{IntoResponse, Response},
|
||||
/// middleware::{self, Next},
|
||||
/// extract::State,
|
||||
/// };
|
||||
/// use tower::ServiceBuilder;
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct State { /* ... */ }
|
||||
/// struct AppState { /* ... */ }
|
||||
///
|
||||
/// async fn my_middleware<B>(
|
||||
/// Extension(state): Extension<State>,
|
||||
/// State(state): State<AppState>,
|
||||
/// req: Request<B>,
|
||||
/// next: Next<B>,
|
||||
/// ) -> Response {
|
||||
|
@ -150,22 +124,24 @@ use tower_service::Service;
|
|||
/// # ().into_response()
|
||||
/// }
|
||||
///
|
||||
/// let state = State { /* ... */ };
|
||||
/// let state = AppState { /* ... */ };
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// let app = Router::with_state(state.clone())
|
||||
/// .route("/", get(|| async { /* ... */ }))
|
||||
/// .layer(
|
||||
/// ServiceBuilder::new()
|
||||
/// .layer(Extension(state))
|
||||
/// .layer(middleware::from_fn(my_middleware)),
|
||||
/// );
|
||||
/// # let app: Router = app;
|
||||
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
|
||||
/// # let app: Router<_> = app;
|
||||
/// ```
|
||||
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
|
||||
from_fn_with_state_arc(Arc::new(state), f)
|
||||
}
|
||||
|
||||
/// Create a middleware from an async function with the given [`Arc`]'ed state.
|
||||
///
|
||||
/// [extractors]: crate::extract::FromRequest
|
||||
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
|
||||
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
|
||||
FromFnLayer {
|
||||
f,
|
||||
state,
|
||||
_extractor: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -175,45 +151,50 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
|
|||
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
|
||||
///
|
||||
/// Created with [`from_fn`]. See that function for more details.
|
||||
pub struct FromFnLayer<F, T> {
|
||||
pub struct FromFnLayer<F, S, T> {
|
||||
f: F,
|
||||
state: Arc<S>,
|
||||
_extractor: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<F, T> Clone for FromFnLayer<F, T>
|
||||
impl<F, S, T> Clone for FromFnLayer<F, S, T>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
f: self.f.clone(),
|
||||
state: Arc::clone(&self.state),
|
||||
_extractor: self._extractor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}
|
||||
|
||||
impl<S, F, T> Layer<S> for FromFnLayer<F, T>
|
||||
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
type Service = FromFn<F, S, T>;
|
||||
type Service = FromFn<F, S, I, T>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
fn layer(&self, inner: I) -> Self::Service {
|
||||
FromFn {
|
||||
f: self.f.clone(),
|
||||
state: Arc::clone(&self.state),
|
||||
inner,
|
||||
_extractor: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, T> fmt::Debug for FromFnLayer<F, T> {
|
||||
impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FromFnLayer")
|
||||
// Write out the type name, without quoting it as `&type_name::<F>()` would
|
||||
.field("f", &format_args!("{}", type_name::<F>()))
|
||||
.field("state", &self.state)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -221,52 +202,48 @@ impl<F, T> fmt::Debug for FromFnLayer<F, T> {
|
|||
/// A middleware created from an async function.
|
||||
///
|
||||
/// Created with [`from_fn`]. See that function for more details.
|
||||
pub struct FromFn<F, S, T> {
|
||||
pub struct FromFn<F, S, I, T> {
|
||||
f: F,
|
||||
inner: S,
|
||||
inner: I,
|
||||
state: Arc<S>,
|
||||
_extractor: PhantomData<fn() -> T>,
|
||||
}
|
||||
|
||||
impl<F, S, T> Clone for FromFn<F, S, T>
|
||||
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
|
||||
where
|
||||
F: Clone,
|
||||
S: Clone,
|
||||
I: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
f: self.f.clone(),
|
||||
inner: self.inner.clone(),
|
||||
state: Arc::clone(&self.state),
|
||||
_extractor: self._extractor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, S, T> Copy for FromFn<F, S, T>
|
||||
where
|
||||
F: Copy,
|
||||
S: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
macro_rules! impl_service {
|
||||
(
|
||||
[$($ty:ident),*], $last:ident
|
||||
) => {
|
||||
#[allow(non_snake_case, unused_mut)]
|
||||
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
|
||||
impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
|
||||
where
|
||||
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
|
||||
$( $ty: FromRequestParts<()> + Send, )*
|
||||
$last: FromRequest<(), B> + Send,
|
||||
$( $ty: FromRequestParts<S> + Send, )*
|
||||
$last: FromRequest<S, B> + Send,
|
||||
Fut: Future<Output = Out> + Send + 'static,
|
||||
Out: IntoResponse + 'static,
|
||||
S: Service<Request<B>, Error = Infallible>
|
||||
I: Service<Request<B>, Error = Infallible>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S::Response: IntoResponse,
|
||||
S::Future: Send + 'static,
|
||||
I::Response: IntoResponse,
|
||||
I::Future: Send + 'static,
|
||||
B: Send + 'static,
|
||||
S: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response;
|
||||
type Error = Infallible;
|
||||
|
@ -281,12 +258,13 @@ macro_rules! impl_service {
|
|||
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
|
||||
|
||||
let mut f = self.f.clone();
|
||||
let state = Arc::clone(&self.state);
|
||||
|
||||
let future = Box::pin(async move {
|
||||
let (mut parts, body) = req.into_parts();
|
||||
|
||||
$(
|
||||
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
|
||||
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
|
||||
Ok(value) => value,
|
||||
Err(rejection) => return rejection.into_response(),
|
||||
};
|
||||
|
@ -294,7 +272,7 @@ macro_rules! impl_service {
|
|||
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
let $last = match $last::from_request(req, &()).await {
|
||||
let $last = match $last::from_request(req, &state).await {
|
||||
Ok(value) => value,
|
||||
Err(rejection) => return rejection.into_response(),
|
||||
};
|
||||
|
@ -342,14 +320,16 @@ impl_service!(
|
|||
T16
|
||||
);
|
||||
|
||||
impl<F, S, T> fmt::Debug for FromFn<F, S, T>
|
||||
impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
I: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FromFnLayer")
|
||||
.field("f", &format_args!("{}", type_name::<F>()))
|
||||
.field("inner", &self.inner)
|
||||
.field("state", &self.state)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,9 @@ mod from_extractor;
|
|||
mod from_fn;
|
||||
|
||||
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
|
||||
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next};
|
||||
pub use self::from_fn::{
|
||||
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
|
||||
};
|
||||
pub use crate::extension::AddExtension;
|
||||
|
||||
pub mod future {
|
||||
|
|
Loading…
Reference in a new issue