diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index b2e9b416..a583d2c7 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning]. - **added:** Add `RouterExt::route_with_tsr` for adding routes with an additional "trailing slash redirect" route ([#1119]) +- **changed:** For methods that accept some `S: Service`, the bounds have been + relaxed so the response type must implement `IntoResponse` rather than being a + literal `Response` [#1119]: https://github.com/tokio-rs/axum/pull/1119 diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 6298a36e..35993b44 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -3,7 +3,7 @@ use axum::{ handler::Handler, http::Request, - response::{Redirect, Response}, + response::{IntoResponse, Redirect}, Router, }; use std::{convert::Infallible, future::ready}; @@ -161,7 +161,8 @@ pub trait RouterExt<B>: sealed::Sealed { /// ``` fn route_with_tsr<T>(self, path: &str, service: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized; } @@ -252,7 +253,8 @@ where fn route_with_tsr<T>(mut self, path: &str, service: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized, { diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 48fcba01..18e1bde8 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -2,7 +2,7 @@ use axum::{ body::Body, handler::Handler, http::Request, - response::Response, + response::IntoResponse, routing::{delete, get, on, post, MethodFilter}, Router, }; @@ -141,7 +141,8 @@ where /// The routes will be nested at `/{resource_name}/:{resource_name}_id`. pub fn nest<T>(mut self, svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { let path = self.show_update_destroy_path(); @@ -154,7 +155,8 @@ where /// The routes will be nested at `/{resource_name}`. pub fn nest_collection<T>(mut self, svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { let path = self.index_create_path(); @@ -172,7 +174,8 @@ where fn route<T>(mut self, path: &str, svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { self.router = self.router.route(path, svc); diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 594b15c2..7c67d882 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -1,7 +1,7 @@ use axum::{ body::{Body, HttpBody}, error_handling::HandleError, - response::Response, + response::IntoResponse, routing::{get_service, Route}, Router, }; @@ -150,8 +150,8 @@ impl<B, T, F> SpaRouter<B, T, F> { impl<B, F, T> From<SpaRouter<B, T, F>> for Router<B> where F: Clone + Send + 'static, - HandleError<Route<B, io::Error>, F, T>: - Service<Request<B>, Response = Response, Error = Infallible>, + HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>, + <HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send, <HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send, B: HttpBody + Send + 'static, T: 'static, diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 3f4b04e3..fa1169fc 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -29,6 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Support any middleware response that implements `IntoResponse` ([#1152]) - **breaking:** Require middleware added with `Handler::layer` to have `Infallible` as the error type ([#1152]) +- **changed:** For methods that accept some `S: Service`, the bounds have been + relaxed so the response type must implement `IntoResponse` rather than being a + literal `Response` [#1171]: https://github.com/tokio-rs/axum/pull/1171 [#1077]: https://github.com/tokio-rs/axum/pull/1077 diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index e0027eed..76c3d025 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -1,11 +1,10 @@ #![doc = include_str!("../docs/error_handling.md")] use crate::{ - body::{boxed, Bytes, HttpBody}, + body::boxed, extract::{FromRequest, RequestParts}, http::{Request, StatusCode}, response::{IntoResponse, Response}, - BoxError, }; use std::{ convert::Infallible, @@ -114,17 +113,16 @@ where } } -impl<S, F, ReqBody, ResBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()> +impl<S, F, ReqBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, + S: Service<Request<ReqBody>> + Clone + Send + 'static, + S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, Fut: Future<Output = Res> + Send, Res: IntoResponse, ReqBody: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { type Response = Response; type Error = Infallible; @@ -142,7 +140,7 @@ where let future = Box::pin(async move { match inner.oneshot(req).await { - Ok(res) => Ok(res.map(boxed)), + Ok(res) => Ok(res.into_response()), Err(err) => Ok(f(err).await.into_response()), } }); @@ -154,10 +152,11 @@ where #[allow(unused_macros)] macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { - impl<S, F, ReqBody, ResBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>> + impl<S, F, ReqBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>> for HandleError<S, F, ($($ty,)*)> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, + S: Service<Request<ReqBody>> + Clone + Send + 'static, + S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, @@ -165,8 +164,6 @@ macro_rules! impl_service { Res: IntoResponse, $( $ty: FromRequest<ReqBody> + Send,)* ReqBody: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { type Response = Response; type Error = Infallible; @@ -202,7 +199,7 @@ macro_rules! impl_service { }; match inner.oneshot(req).await { - Ok(res) => Ok(res.map(boxed)), + Ok(res) => Ok(res.into_response()), Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)), } }); diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 45bb9513..8399d8c5 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -1,8 +1,6 @@ use crate::{ - body::{Bytes, HttpBody}, extract::{FromRequest, RequestParts}, response::{IntoResponse, Response}, - BoxError, }; use futures_util::{future::BoxFuture, ready}; use http::Request; @@ -90,6 +88,8 @@ use tower_service::Service; /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` +/// +/// [`Bytes`]: bytes::Bytes pub fn from_extractor<E>() -> FromExtractorLayer<E> { FromExtractorLayer(PhantomData) } @@ -166,13 +166,12 @@ where } } -impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for FromExtractor<S, E> +impl<S, E, ReqBody> Service<Request<ReqBody>> for FromExtractor<S, E> where E: FromRequest<ReqBody> + 'static, ReqBody: Default + Send + 'static, - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, + S: Service<Request<ReqBody>> + Clone, + S::Response: IntoResponse, { type Response = Response; type Error = S::Error; @@ -225,13 +224,12 @@ pin_project! { } } -impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E> +impl<ReqBody, S, E> Future for ResponseFuture<ReqBody, S, E> where E: FromRequest<ReqBody>, - S: Service<Request<ReqBody>, Response = Response<ResBody>>, + S: Service<Request<ReqBody>>, + S::Response: IntoResponse, ReqBody: Default, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { type Output = Result<Response, S::Error>; @@ -259,7 +257,7 @@ where StateProj::Call { future } => { return future .poll(cx) - .map(|result| result.map(|response| response.map(crate::body::boxed))); + .map(|result| result.map(IntoResponse::into_response)); } }; diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 8e46fa69..5322adfb 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -1,8 +1,4 @@ -use crate::{ - body::{self, Bytes, HttpBody}, - response::{IntoResponse, Response}, - BoxError, -}; +use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, RequestParts}; use futures_util::future::BoxFuture; use http::Request; @@ -16,7 +12,6 @@ use std::{ task::{Context, Poll}, }; use tower::{util::BoxCloneService, ServiceBuilder}; -use tower_http::ServiceBuilderExt; use tower_layer::Layer; use tower_service::Service; @@ -256,20 +251,19 @@ where macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl<F, Fut, Out, S, ReqBody, ResBody, $($ty,)*> Service<Request<ReqBody>> for FromFn<F, S, ($($ty,)*)> + impl<F, Fut, Out, S, ReqBody, $($ty,)*> Service<Request<ReqBody>> for FromFn<F, S, ($($ty,)*)> where F: FnMut($($ty),*, Next<ReqBody>) -> Fut + Clone + Send + 'static, $( $ty: FromRequest<ReqBody> + Send, )* Fut: Future<Output = Out> + Send + 'static, Out: IntoResponse + 'static, - S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + S: Service<Request<ReqBody>, Error = Infallible> + Clone + Send + 'static, + S::Response: IntoResponse, S::Future: Send + 'static, ReqBody: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { type Response = Response; type Error = Infallible; @@ -296,7 +290,7 @@ macro_rules! impl_service { let inner = ServiceBuilder::new() .boxed_clone() - .map_response_body(body::boxed) + .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; @@ -370,7 +364,7 @@ impl fmt::Debug for ResponseFuture { #[cfg(test)] mod tests { use super::*; - use crate::{body::Empty, routing::get, Router}; + use crate::{body::Body, routing::get, Router}; use http::{HeaderMap, StatusCode}; use tower::ServiceExt; @@ -392,12 +386,7 @@ mod tests { .layer(from_fn(insert_header)); let res = app - .oneshot( - Request::builder() - .uri("/") - .body(body::boxed(Empty::new())) - .unwrap(), - ) + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index c1f86626..fc86c9ab 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -7,7 +7,6 @@ use crate::{ http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, - BoxError, }; use axum_core::response::IntoResponse; use bytes::BytesMut; @@ -17,7 +16,7 @@ use std::{ marker::PhantomData, task::{Context, Poll}, }; -use tower::{service_fn, util::MapResponseLayer, ServiceBuilder, ServiceExt}; +use tower::{service_fn, util::MapResponseLayer, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; @@ -76,12 +75,11 @@ macro_rules! top_level_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name<S, ReqBody, ResBody>(svc: S) -> MethodRouter<ReqBody, S::Error> + pub fn $name<S, ReqBody>(svc: S) -> MethodRouter<ReqBody, S::Error> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, + S: Service<Request<ReqBody>> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { on_service(MethodFilter::$method, svc) } @@ -208,15 +206,14 @@ macro_rules! chained_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name<S, ResBody>(self, svc: S) -> Self + pub fn $name<S>(self, svc: S) -> Self where - S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> + S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { self.on_service(MethodFilter::$method, svc) } @@ -316,15 +313,11 @@ top_level_service_fn!(trace_service, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on_service<S, ReqBody, ResBody>( - filter: MethodFilter, - svc: S, -) -> MethodRouter<ReqBody, S::Error> +pub fn on_service<S, ReqBody>(filter: MethodFilter, svc: S) -> MethodRouter<ReqBody, S::Error> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, + S: Service<Request<ReqBody>> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { MethodRouter::new().on_service(filter, svc) } @@ -382,12 +375,11 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any_service<S, ReqBody, ResBody>(svc: S) -> MethodRouter<ReqBody, S::Error> +pub fn any_service<S, ReqBody>(svc: S) -> MethodRouter<ReqBody, S::Error> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, + S: Service<Request<ReqBody>> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { MethodRouter::new().fallback(svc).skip_allow_header() } @@ -684,17 +676,13 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn on_service<S, ResBody>(self, filter: MethodFilter, svc: S) -> Self + pub fn on_service<S>(self, filter: MethodFilter, svc: S) -> Self where - S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> - + Clone - + Send - + 'static, + S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { - self.on_service_boxed_response_body(filter, svc.map_response(|res| res.map(boxed))) + self.on_service_boxed_response_body(filter, svc) } chained_service_fn!(delete_service, DELETE); @@ -707,23 +695,20 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> { chained_service_fn!(trace_service, TRACE); #[doc = include_str!("../docs/method_routing/fallback.md")] - pub fn fallback<S, ResBody>(mut self, svc: S) -> Self + pub fn fallback<S>(mut self, svc: S) -> Self where - S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E> - + Clone - + Send - + 'static, + S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, - ResBody: HttpBody<Data = Bytes> + Send + 'static, - ResBody::Error: Into<BoxError>, { - self.fallback = Fallback::Custom(Route::new(svc.map_response(|res| res.map(boxed)))); + self.fallback = Fallback::Custom(Route::new(svc)); self } fn fallback_boxed_response_body<S>(mut self, svc: S) -> Self where - S: Service<Request<ReqBody>, Response = Response, Error = E> + Clone + Send + 'static, + S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); @@ -886,9 +871,10 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> { pub fn handle_error<F, T>(self, f: F) -> MethodRouter<ReqBody, Infallible> where F: Clone + Send + 'static, - HandleError<Route<ReqBody, E>, F, T>: - Service<Request<ReqBody>, Response = Response, Error = Infallible>, + HandleError<Route<ReqBody, E>, F, T>: Service<Request<ReqBody>, Error = Infallible>, <HandleError<Route<ReqBody, E>, F, T> as Service<Request<ReqBody>>>::Future: Send, + <HandleError<Route<ReqBody, E>, F, T> as Service<Request<ReqBody>>>::Response: + IntoResponse + Send, T: 'static, E: 'static, ReqBody: 'static, @@ -898,7 +884,8 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> { fn on_service_boxed_response_body<S>(self, filter: MethodFilter, svc: S) -> Self where - S: Service<Request<ReqBody>, Response = Response, Error = E> + Clone + Send + 'static, + S: Service<Request<ReqBody>, Error = E> + Clone + Send + 'static, + S::Response: IntoResponse + 'static, S::Future: Send + 'static, { macro_rules! set_service { @@ -1319,14 +1306,22 @@ mod tests { async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where - S: Service<Request<Body>, Response = Response, Error = Infallible>, + S: Service<Request<Body>, Error = Infallible>, + S::Response: IntoResponse, { let request = Request::builder() .uri("/") .method(method) .body(Body::empty()) .unwrap(); - let response = svc.ready().await.unwrap().call(request).await.unwrap(); + let response = svc + .ready() + .await + .unwrap() + .call(request) + .await + .unwrap() + .into_response(); let (parts, body) = response.into_parts(); let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap(); (parts.status, parts.headers, body) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 5b6ee682..13536397 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -122,7 +122,8 @@ where #[doc = include_str!("../docs/routing/route.md")] pub fn route<T>(mut self, path: &str, service: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { if path.is_empty() { @@ -176,7 +177,8 @@ where #[doc = include_str!("../docs/routing/nest.md")] pub fn nest<T>(mut self, mut path: &str, svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { if path.is_empty() { @@ -368,7 +370,8 @@ where #[doc = include_str!("../docs/routing/fallback.md")] pub fn fallback<T>(mut self, svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static, + T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, T::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 6e56d77a..71ff3e05 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -2,6 +2,7 @@ use crate::{ body::{boxed, Body, Empty, HttpBody}, response::Response, }; +use axum_core::response::IntoResponse; use bytes::Bytes; use http::{ header::{self, CONTENT_LENGTH}, @@ -30,10 +31,13 @@ pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, impl<B, E> Route<B, E> { pub(super) fn new<T>(svc: T) -> Self where - T: Service<Request<B>, Response = Response, Error = E> + Clone + Send + 'static, + T: Service<Request<B>, Error = E> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - Self(BoxCloneService::new(svc)) + Self(BoxCloneService::new( + svc.map_response(IntoResponse::into_response), + )) } pub(crate) fn oneshot_inner(