diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9a3292fc..a07b6715 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `map_request`, `map_request_with_state`, and `map_request_with_state_arc` for transforming the request with an async function ([#1408]) +- **added:** Add `map_response`, `map_response_with_state`, and + `map_response_with_state_arc` for transforming the response with an async + function ([#1414]) - **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400]) - **changed:** `Router` no longer implements `Service`, call `.into_service()` on it to obtain a `RouterService` that does ([#1368]) @@ -39,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1397]: https://github.com/tokio-rs/axum/pull/1397 [#1400]: https://github.com/tokio-rs/axum/pull/1400 [#1408]: https://github.com/tokio-rs/axum/pull/1408 +[#1414]: https://github.com/tokio-rs/axum/pull/1414 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs new file mode 100644 index 00000000..a8b332d1 --- /dev/null +++ b/axum/src/middleware/map_response.rs @@ -0,0 +1,341 @@ +use crate::response::{IntoResponse, Response}; +use axum_core::extract::FromRequestParts; +use futures_util::future::BoxFuture; +use http::Request; +use std::{ + any::type_name, + convert::Infallible, + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Create a middleware from an async function that transforms a response. +/// +/// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific +/// extractors. +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::map_response, +/// response::Response, +/// }; +/// +/// async fn set_header(mut response: Response) -> Response { +/// response.headers_mut().insert("x-foo", "foo".parse().unwrap()); +/// response +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .layer(map_response(set_header)); +/// # let _: Router = app; +/// ``` +/// +/// # Running extractors +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::map_response, +/// extract::Path, +/// response::Response, +/// }; +/// use std::collections::HashMap; +/// +/// async fn log_path_params( +/// Path(path_params): Path>, +/// response: Response, +/// ) -> Response { +/// tracing::debug!(?path_params); +/// response +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .layer(map_response(log_path_params)); +/// # let _: Router = app; +/// ``` +/// +/// Note that to access state you must use either [`map_response_with_state`] or [`map_response_with_state_arc`]. +pub fn map_response(f: F) -> MapResponseLayer { + map_response_with_state((), f) +} + +/// Create a middleware from an async function that transforms a response, with the given state. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// Router, +/// http::StatusCode, +/// routing::get, +/// response::Response, +/// middleware::map_response_with_state, +/// extract::State, +/// }; +/// +/// #[derive(Clone)] +/// struct AppState { /* ... */ } +/// +/// async fn my_middleware( +/// State(state): State, +/// // you can add more extractors here but they must +/// // all implement `FromRequestParts` +/// // `FromRequest` is not allowed +/// response: Response, +/// ) -> Response { +/// // do something with `state` and `response`... +/// response +/// } +/// +/// let state = AppState { /* ... */ }; +/// +/// let app = Router::with_state(state.clone()) +/// .route("/", get(|| async { /* ... */ })) +/// .route_layer(map_response_with_state(state, my_middleware)); +/// # let app: Router<_> = app; +/// ``` +pub fn map_response_with_state(state: S, f: F) -> MapResponseLayer { + map_response_with_state_arc(Arc::new(state), f) +} + +/// Create a middleware from an async function that transforms a response, with the given [`Arc`]'ed +/// state. +/// +/// See [`map_response_with_state`] for an example. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +pub fn map_response_with_state_arc(state: Arc, f: F) -> MapResponseLayer { + MapResponseLayer { + f, + state, + _extractor: PhantomData, + } +} + +/// A [`tower::Layer`] from an async function that transforms a response. +/// +/// Created with [`map_response`]. See that function for more details. +pub struct MapResponseLayer { + f: F, + state: Arc, + _extractor: PhantomData T>, +} + +impl Clone for MapResponseLayer +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + state: Arc::clone(&self.state), + _extractor: self._extractor, + } + } +} + +impl Layer for MapResponseLayer +where + F: Clone, +{ + type Service = MapResponse; + + fn layer(&self, inner: I) -> Self::Service { + MapResponse { + f: self.f.clone(), + state: Arc::clone(&self.state), + inner, + _extractor: PhantomData, + } + } +} + +impl fmt::Debug for MapResponseLayer +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResponseLayer") + // Write out the type name, without quoting it as `&type_name::()` would + .field("f", &format_args!("{}", type_name::())) + .field("state", &self.state) + .finish() + } +} + +/// A middleware created from an async function that transforms a response. +/// +/// Created with [`map_response`]. See that function for more details. +pub struct MapResponse { + f: F, + inner: I, + state: Arc, + _extractor: PhantomData T>, +} + +impl Clone for MapResponse +where + F: Clone, + I: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + inner: self.inner.clone(), + state: Arc::clone(&self.state), + _extractor: self._extractor, + } + } +} + +macro_rules! impl_service { + ( + $($ty:ident),* + ) => { + #[allow(non_snake_case, unused_mut)] + impl Service> for MapResponse + where + F: FnMut($($ty,)* Response) -> Fut + Clone + Send + 'static, + $( $ty: FromRequestParts + Send, )* + Fut: Future + Send + 'static, + Fut::Output: IntoResponse + Send + 'static, + I: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + I::Future: Send + 'static, + B: Send + 'static, + ResBody: Send + 'static, + S: Send + Sync + 'static, + { + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + + fn call(&mut self, req: Request) -> Self::Future { + let not_ready_inner = self.inner.clone(); + let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); + + let mut f = self.f.clone(); + let _state = Arc::clone(&self.state); + + let future = Box::pin(async move { + let (mut parts, body) = req.into_parts(); + + $( + let $ty = match $ty::from_request_parts(&mut parts, &_state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + )* + + let req = Request::from_parts(parts, body); + + match ready_inner.call(req).await { + Ok(res) => { + f($($ty,)* res).await.into_response() + } + Err(err) => match err {} + } + }); + + ResponseFuture { + inner: future + } + } + } + }; +} + +impl_service!(); +impl_service!(T1); +impl_service!(T1, T2); +impl_service!(T1, T2, T3); +impl_service!(T1, T2, T3, T4); +impl_service!(T1, T2, T3, T4, T5); +impl_service!(T1, T2, T3, T4, T5, T6); +impl_service!(T1, T2, T3, T4, T5, T6, T7); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); + +impl fmt::Debug for MapResponse +where + S: fmt::Debug, + I: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapResponse") + .field("f", &format_args!("{}", type_name::())) + .field("inner", &self.inner) + .field("state", &self.state) + .finish() + } +} + +/// Response future for [`MapResponse`]. +pub struct ResponseFuture { + inner: BoxFuture<'static, Response>, +} + +impl Future for ResponseFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(cx).map(Ok) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::{test_helpers::TestClient, Router}; + + #[tokio::test] + async fn works() { + async fn add_header(mut res: Response) -> Response { + res.headers_mut().insert("x-foo", "foo".parse().unwrap()); + res + } + + let app = Router::new().layer(map_response(add_header)); + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.headers()["x-foo"], "foo"); + } +} diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 9e14825c..5b138828 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -5,6 +5,7 @@ mod from_extractor; mod from_fn; mod map_request; +mod map_response; pub use self::from_extractor::{ from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor, @@ -17,6 +18,10 @@ pub use self::map_request::{ map_request, map_request_with_state, map_request_with_state_arc, IntoMapRequestResult, MapRequest, MapRequestLayer, }; +pub use self::map_response::{ + map_response, map_response_with_state, map_response_with_state_arc, MapResponse, + MapResponseLayer, +}; pub use crate::extension::AddExtension; pub mod future { @@ -25,4 +30,5 @@ pub mod future { pub use super::from_extractor::ResponseFuture as FromExtractorResponseFuture; pub use super::from_fn::ResponseFuture as FromFnResponseFuture; pub use super::map_request::ResponseFuture as MapRequestResponseFuture; + pub use super::map_response::ResponseFuture as MapResponseResponseFuture; }