diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 705b26ce..b8ae2b57 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 without any routes will now result in a panic. Previously, this just did nothing. [#1327] +## Middleware + +- **added**: Add `middleware::from_fn_with_state` and + `middleware::from_fn_with_state_arc` to enable running extractors that require + state ([#1342]) + [#1327]: https://github.com/tokio-rs/axum/pull/1327 +[#1342]: https://github.com/tokio-rs/axum/pull/1342 # 0.6.0-rc.1 (23. August, 2022) diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 1830bbd6..d070d1ec 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -390,45 +390,12 @@ middleware you don't have to worry about any of this. # Accessing state in middleware -Handlers can access state using the [`State`] extractor but this isn't available -to middleware. Instead you have to pass the state directly to middleware using -either closure captures (for [`axum::middleware::from_fn`]) or regular struct -fields (if you're implementing a [`tower::Layer`]) +How to make state available to middleware depends on how the middleware is +written. ## Accessing state in `axum::middleware::from_fn` -```rust -use axum::{ - Router, - routing::get, - middleware::{self, Next}, - response::Response, - extract::State, - http::Request, -}; - -#[derive(Clone)] -struct AppState {} - -async fn my_middleware( - state: AppState, - req: Request, - next: Next, -) -> Response { - next.run(req).await -} - -async fn handler(_: State) {} - -let state = AppState {}; - -let app = Router::with_state(state.clone()) - .route("/", get(handler)) - .layer(middleware::from_fn(move |req, next| { - my_middleware(state.clone(), req, next) - })); -# let _: Router<_> = app; -``` +Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state). ## Accessing state in custom `tower::Layer`s @@ -482,7 +449,10 @@ where } fn call(&mut self, req: Request) -> Self::Future { - // do something with `self.state` + // Do something with `self.state`. + // + // See `axum::RequestExt` for how to run extractors directly from + // a `Request`. self.inner.call(req) } diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 9728a43a..ac22bec1 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -9,6 +9,7 @@ use std::{ future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use tower::{util::BoxCloneService, ServiceBuilder}; @@ -90,9 +91,16 @@ use tower_service::Service; /// # let app: Router = app; /// ``` /// -/// # Passing state +/// [extractors]: crate::extract::FromRequest +pub fn from_fn(f: F) -> FromFnLayer { + from_fn_with_state((), f) +} + +/// Create a middleware from an async function with the given state. /// -/// State can be passed to the function like so: +/// See [`State`](crate::extract::State) for more details about accessing state. +/// +/// # Example /// /// ```rust /// use axum::{ @@ -100,49 +108,15 @@ use tower_service::Service; /// http::{Request, StatusCode}, /// routing::get, /// response::{IntoResponse, Response}, -/// middleware::{self, Next} -/// }; -/// -/// #[derive(Clone)] -/// struct State { /* ... */ } -/// -/// async fn my_middleware( -/// req: Request, -/// next: Next, -/// state: State, -/// ) -> Response { -/// // ... -/// # ().into_response() -/// } -/// -/// let state = State { /* ... */ }; -/// -/// let app = Router::new() -/// .route("/", get(|| async { /* ... */ })) -/// .route_layer(middleware::from_fn(move |req, next| { -/// my_middleware(req, next, state.clone()) -/// })); -/// # let app: Router = app; -/// ``` -/// -/// Or via extensions: -/// -/// ```rust -/// use axum::{ -/// Router, -/// extract::Extension, -/// http::{Request, StatusCode}, -/// routing::get, -/// response::{IntoResponse, Response}, /// middleware::{self, Next}, +/// extract::State, /// }; -/// use tower::ServiceBuilder; /// /// #[derive(Clone)] -/// struct State { /* ... */ } +/// struct AppState { /* ... */ } /// /// async fn my_middleware( -/// Extension(state): Extension, +/// State(state): State, /// req: Request, /// next: Next, /// ) -> Response { @@ -150,22 +124,24 @@ use tower_service::Service; /// # ().into_response() /// } /// -/// let state = State { /* ... */ }; +/// let state = AppState { /* ... */ }; /// -/// let app = Router::new() +/// let app = Router::with_state(state.clone()) /// .route("/", get(|| async { /* ... */ })) -/// .layer( -/// ServiceBuilder::new() -/// .layer(Extension(state)) -/// .layer(middleware::from_fn(my_middleware)), -/// ); -/// # let app: Router = app; +/// .route_layer(middleware::from_fn_with_state(state, my_middleware)); +/// # let app: Router<_> = app; /// ``` +pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { + from_fn_with_state_arc(Arc::new(state), f) +} + +/// Create a middleware from an async function with the given [`Arc`]'ed state. /// -/// [extractors]: crate::extract::FromRequest -pub fn from_fn(f: F) -> FromFnLayer { +/// See [`State`](crate::extract::State) for more details about accessing state. +pub fn from_fn_with_state_arc(state: Arc, f: F) -> FromFnLayer { FromFnLayer { f, + state, _extractor: PhantomData, } } @@ -175,45 +151,50 @@ pub fn from_fn(f: F) -> FromFnLayer { /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// /// Created with [`from_fn`]. See that function for more details. -pub struct FromFnLayer { +pub struct FromFnLayer { f: F, + state: Arc, _extractor: PhantomData T>, } -impl Clone for FromFnLayer +impl Clone for FromFnLayer where F: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), + state: Arc::clone(&self.state), _extractor: self._extractor, } } } -impl Copy for FromFnLayer where F: Copy {} - -impl Layer for FromFnLayer +impl Layer for FromFnLayer where F: Clone, { - type Service = FromFn; + type Service = FromFn; - fn layer(&self, inner: S) -> Self::Service { + fn layer(&self, inner: I) -> Self::Service { FromFn { f: self.f.clone(), + state: Arc::clone(&self.state), inner, _extractor: PhantomData, } } } -impl fmt::Debug for FromFnLayer { +impl fmt::Debug for FromFnLayer +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) + .field("state", &self.state) .finish() } } @@ -221,52 +202,48 @@ impl fmt::Debug for FromFnLayer { /// A middleware created from an async function. /// /// Created with [`from_fn`]. See that function for more details. -pub struct FromFn { +pub struct FromFn { f: F, - inner: S, + inner: I, + state: Arc, _extractor: PhantomData T>, } -impl Clone for FromFn +impl Clone for FromFn where F: Clone, - S: Clone, + I: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), + state: Arc::clone(&self.state), _extractor: self._extractor, } } } -impl Copy for FromFn -where - F: Copy, - S: Copy, -{ -} - macro_rules! impl_service { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] - impl Service> for FromFn + impl Service> for FromFn where F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequestParts<()> + Send, )* - $last: FromRequest<(), B> + Send, + $( $ty: FromRequestParts + Send, )* + $last: FromRequest + Send, Fut: Future + Send + 'static, Out: IntoResponse + 'static, - S: Service, Error = Infallible> + I: Service, Error = Infallible> + Clone + Send + 'static, - S::Response: IntoResponse, - S::Future: Send + 'static, + I::Response: IntoResponse, + I::Future: Send + 'static, B: Send + 'static, + S: Send + Sync + 'static, { type Response = Response; type Error = Infallible; @@ -281,12 +258,13 @@ macro_rules! impl_service { let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); + let state = Arc::clone(&self.state); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( - let $ty = match $ty::from_request_parts(&mut parts, &()).await { + let $ty = match $ty::from_request_parts(&mut parts, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; @@ -294,7 +272,7 @@ macro_rules! impl_service { let req = Request::from_parts(parts, body); - let $last = match $last::from_request(req, &()).await { + let $last = match $last::from_request(req, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; @@ -342,14 +320,16 @@ impl_service!( T16 ); -impl fmt::Debug for FromFn +impl fmt::Debug for FromFn where S: fmt::Debug, + I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) + .field("state", &self.state) .finish() } } diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index f8be812b..15132da4 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -6,7 +6,9 @@ mod from_extractor; mod from_fn; pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer}; -pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next}; +pub use self::from_fn::{ + from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next, +}; pub use crate::extension::AddExtension; pub mod future {