Support any middleware response that implements IntoResponse (#1152)

* Support any middleware response that implements `IntoResponse`

* Require `Infallible` for middleware added with `Handler::layer`
This commit is contained in:
David Pedersen 2022-07-13 10:38:19 +02:00 committed by GitHub
parent 0a7fdd0b05
commit 329bd5f9b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 54 additions and 47 deletions

View file

@ -25,6 +25,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Added `debug_handler` which is an attribute macro that improves
type errors when applied to handler function. It is re-exported from
`axum-macros`
- **added:** Support any middleware response that implements `IntoResponse` ([#1152])
- **breaking:** Require middleware added with `Handler::layer` to have
`Infallible` as the error type ([#1152])
[#1077]: https://github.com/tokio-rs/axum/pull/1077
[#1088]: https://github.com/tokio-rs/axum/pull/1088
@ -32,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1119]: https://github.com/tokio-rs/axum/pull/1119
[#1130]: https://github.com/tokio-rs/axum/pull/1130
[#1135]: https://github.com/tokio-rs/axum/pull/1135
[#1152]: https://github.com/tokio-rs/axum/pull/1152
[#924]: https://github.com/tokio-rs/axum/pull/924
# 0.5.10 (28. June, 2022)

View file

@ -36,14 +36,13 @@
#![doc = include_str!("../docs/debugging_handler_type_errors.md")]
use crate::{
body::{boxed, Body, Bytes, HttpBody},
body::Body,
extract::{connect_info::IntoMakeServiceWithConnectInfo, FromRequest, RequestParts},
response::{IntoResponse, Response},
routing::IntoMakeService,
BoxError,
};
use http::Request;
use std::{fmt, future::Future, marker::PhantomData, pin::Pin};
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
@ -284,15 +283,13 @@ where
}
}
impl<S, T, ReqBody, ResBody> Handler<T, ReqBody> for Layered<S, T>
impl<S, T, ReqBody> Handler<T, ReqBody> for Layered<S, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: IntoResponse,
S: Service<Request<ReqBody>, Error = Infallible> + Clone + Send + 'static,
S::Response: IntoResponse,
S::Future: Send,
T: 'static,
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Future = future::LayeredFuture<S, ReqBody>;
@ -301,8 +298,8 @@ where
let future: Map<_, fn(Result<S::Response, S::Error>) -> _> =
self.svc.oneshot(req).map(|result| match result {
Ok(res) => res.map(boxed),
Err(res) => res.into_response(),
Ok(res) => res.into_response(),
Err(err) => match err {},
});
future::LayeredFuture::new(future)

View file

@ -9,6 +9,7 @@ use crate::{
routing::{future::RouteFuture, Fallback, MethodFilter, Route},
BoxError,
};
use axum_core::response::IntoResponse;
use bytes::BytesMut;
use std::{
convert::Infallible,
@ -16,8 +17,7 @@ use std::{
marker::PhantomData,
task::{Context, Poll},
};
use tower::{service_fn, ServiceBuilder, ServiceExt};
use tower_http::map_response_body::MapResponseBodyLayer;
use tower::{service_fn, util::MapResponseLayer, ServiceBuilder, ServiceExt};
use tower_layer::Layer;
use tower_service::Service;
@ -731,23 +731,16 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
}
#[doc = include_str!("../docs/method_routing/layer.md")]
pub fn layer<L, NewReqBody, NewResBody, NewError>(
self,
layer: L,
) -> MethodRouter<NewReqBody, NewError>
pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<NewReqBody, NewError>
where
L: Layer<Route<ReqBody, E>>,
L::Service: Service<Request<NewReqBody>, Response = Response<NewResBody>, Error = NewError>
+ Clone
+ Send
+ 'static,
L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed))
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let layer_fn = |s| layer.layer(s);
@ -768,20 +761,16 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
}
#[doc = include_str!("../docs/method_routing/route_layer.md")]
pub fn route_layer<L, NewResBody>(self, layer: L) -> MethodRouter<ReqBody, E>
pub fn route_layer<L>(self, layer: L) -> MethodRouter<ReqBody, E>
where
L: Layer<Route<ReqBody, E>>,
L::Service: Service<Request<ReqBody>, Response = Response<NewResBody>, Error = E>
+ Clone
+ Send
+ 'static,
L::Service: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request<ReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<ReqBody>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new()
.layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(boxed))
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let layer_fn = |s| layer.layer(s);

View file

@ -2,13 +2,13 @@
use self::{future::RouteFuture, not_found::NotFound};
use crate::{
body::{boxed, Body, Bytes, HttpBody},
body::{Body, HttpBody},
extract::connect_info::IntoMakeServiceWithConnectInfo,
response::Response,
routing::strip_prefix::StripPrefix,
util::try_downcast,
BoxError,
};
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{
@ -19,8 +19,7 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tower::{layer::layer_fn, ServiceBuilder};
use tower_http::map_response_body::MapResponseBodyLayer;
use tower::{layer::layer_fn, util::MapResponseLayer, ServiceBuilder};
use tower_layer::Layer;
use tower_service::Service;
@ -291,19 +290,17 @@ where
}
#[doc = include_str!("../docs/routing/layer.md")]
pub fn layer<L, NewReqBody, NewResBody>(self, layer: L) -> Router<NewReqBody>
pub fn layer<L, NewReqBody>(self, layer: L) -> Router<NewReqBody>
where
L: Layer<Route<B>>,
L::Service:
Service<Request<NewReqBody>, Response = Response<NewResBody>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new()
.map_err(Into::into)
.layer(MapResponseBodyLayer::new(boxed))
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
@ -332,18 +329,17 @@ where
}
#[doc = include_str!("../docs/routing/route_layer.md")]
pub fn route_layer<L, NewResBody>(self, layer: L) -> Self
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>>,
L::Service: Service<Request<B>, Response = Response<NewResBody>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
let layer = ServiceBuilder::new()
.map_err(Into::into)
.layer(MapResponseBodyLayer::new(boxed))
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();

View file

@ -19,7 +19,9 @@ use std::{
task::{Context, Poll},
time::Duration,
};
use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder, ServiceExt};
use tower::{
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
};
use tower_http::{auth::RequireAuthorizationLayer, limit::RequestBodyLimitLayer};
use tower_service::Service;
@ -720,3 +722,22 @@ async fn limited_body_with_streaming_body() {
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
#[tokio::test]
async fn layer_response_into_response() {
fn map_response<B>(_res: Response<B>) -> Result<Response<B>, impl IntoResponse> {
let headers = [("x-foo", "bar")];
let status = StatusCode::IM_A_TEAPOT;
Err((headers, status))
}
let app = Router::new()
.route("/", get(|| async {}))
.layer(MapResponseLayer::new(map_response));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.headers()["x-foo"], "bar");
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
}