Simplify handler trait (#221)

Rely on the `impl FromRequest for (T, ...)` rather than extracting things directly inside the macro.
This commit is contained in:
David Pedersen 2021-08-20 20:36:34 +02:00 committed by GitHub
parent e8bc3f5082
commit 44c58bdf5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 110 deletions

View file

@ -2,10 +2,7 @@
use crate::body::{box_body, BoxBody};
use crate::util::{Either, EitherProj};
use futures_util::{
future::{BoxFuture, Map},
ready,
};
use futures_util::{future::BoxFuture, ready};
use http::{Method, Request, Response};
use http_body::Empty;
use pin_project_lite::pin_project;
@ -67,8 +64,5 @@ where
opaque_future! {
/// The response future for [`IntoService`](super::IntoService).
pub type IntoServiceFuture =
Map<
BoxFuture<'static, Response<BoxBody>>,
fn(Response<BoxBody>) -> Result<Response<BoxBody>, Infallible>,
>;
BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
}

View file

@ -1,5 +1,9 @@
use super::Handler;
use crate::body::BoxBody;
use crate::{
body::{box_body, BoxBody},
extract::{FromRequest, RequestParts},
response::IntoResponse,
};
use http::{Request, Response};
use std::{
convert::Infallible,
@ -12,12 +16,12 @@ use tower::Service;
/// An adapter that makes a [`Handler`] into a [`Service`].
///
/// Created with [`Handler::into_service`].
pub struct IntoService<H, B, T> {
pub struct IntoService<H, T> {
handler: H,
_marker: PhantomData<fn() -> (B, T)>,
_marker: PhantomData<fn() -> T>,
}
impl<H, B, T> IntoService<H, B, T> {
impl<H, T> IntoService<H, T> {
pub(super) fn new(handler: H) -> Self {
Self {
handler,
@ -26,7 +30,7 @@ impl<H, B, T> IntoService<H, B, T> {
}
}
impl<H, B, T> fmt::Debug for IntoService<H, B, T> {
impl<H, T> fmt::Debug for IntoService<H, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("IntoService")
.field(&format_args!("..."))
@ -34,7 +38,7 @@ impl<H, B, T> fmt::Debug for IntoService<H, B, T> {
}
}
impl<H, B, T> Clone for IntoService<H, B, T>
impl<H, T> Clone for IntoService<H, T>
where
H: Clone,
{
@ -46,9 +50,11 @@ where
}
}
impl<H, T, B> Service<Request<B>> for IntoService<H, B, T>
impl<H, T, B> Service<Request<B>> for IntoService<H, T>
where
H: Handler<B, T> + Clone + Send + 'static,
H: Handler<T>,
T: FromRequest<B> + Send,
T::Rejection: Send,
B: Send + 'static,
{
type Response = Response<BoxBody>;
@ -63,10 +69,16 @@ where
}
fn call(&mut self, req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let handler = self.handler.clone();
let future = Handler::call(handler, req).map(Ok::<_, Infallible> as _);
let future = Box::pin(async move {
let mut req = RequestParts::new(req);
let input = T::from_request(&mut req).await;
let res = match input {
Ok(input) => Handler::call(handler, input).await,
Err(rejection) => rejection.into_response().map(box_body),
};
Ok::<_, Infallible>(res)
});
super::future::IntoServiceFuture { future }
}

View file

@ -2,7 +2,7 @@
use crate::{
body::{box_body, BoxBody},
extract::FromRequest,
extract::{FromRequest, RequestParts},
response::IntoResponse,
routing::{EmptyRouter, MethodFilter},
service::HandleError,
@ -44,9 +44,9 @@ pub use self::into_service::IntoService;
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn any<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn any<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::all(), handler)
}
@ -54,9 +54,9 @@ where
/// Route `CONNECT` requests to the given handler.
///
/// See [`get`] for an example.
pub fn connect<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn connect<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::CONNECT, handler)
}
@ -64,9 +64,9 @@ where
/// Route `DELETE` requests to the given handler.
///
/// See [`get`] for an example.
pub fn delete<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn delete<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::DELETE, handler)
}
@ -93,9 +93,9 @@ where
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn get<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::GET | MethodFilter::HEAD, handler)
}
@ -103,9 +103,9 @@ where
/// Route `HEAD` requests to the given handler.
///
/// See [`get`] for an example.
pub fn head<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn head<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::HEAD, handler)
}
@ -113,9 +113,9 @@ where
/// Route `OPTIONS` requests to the given handler.
///
/// See [`get`] for an example.
pub fn options<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn options<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::OPTIONS, handler)
}
@ -123,9 +123,9 @@ where
/// Route `PATCH` requests to the given handler.
///
/// See [`get`] for an example.
pub fn patch<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn patch<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::PATCH, handler)
}
@ -133,9 +133,9 @@ where
/// Route `POST` requests to the given handler.
///
/// See [`get`] for an example.
pub fn post<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn post<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::POST, handler)
}
@ -143,9 +143,9 @@ where
/// Route `PUT` requests to the given handler.
///
/// See [`get`] for an example.
pub fn put<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn put<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::PUT, handler)
}
@ -153,9 +153,9 @@ where
/// Route `TRACE` requests to the given handler.
///
/// See [`get`] for an example.
pub fn trace<H, B, T>(handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn trace<H, T>(handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
on(MethodFilter::TRACE, handler)
}
@ -179,9 +179,9 @@ where
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn on<H, B, T>(method: MethodFilter, handler: H) -> OnMethod<H, B, T, EmptyRouter>
pub fn on<H, T>(method: MethodFilter, handler: H) -> OnMethod<H, T, EmptyRouter>
where
H: Handler<B, T>,
H: Handler<T>,
{
OnMethod {
method,
@ -206,14 +206,14 @@ pub(crate) mod sealed {
///
/// See the [module docs](crate::handler) for more details.
#[async_trait]
pub trait Handler<B, T>: Clone + Send + Sized + 'static {
pub trait Handler<T>: Clone + Send + Sync + Sized + 'static {
// This seals the trait. We cannot use the regular "sealed super trait"
// approach due to coherence.
#[doc(hidden)]
type Sealed: sealed::HiddentTrait;
/// Call the handler with the given request.
async fn call(self, req: Request<B>) -> Response<BoxBody>;
/// Call the handler.
async fn call(self, input: T) -> Response<BoxBody>;
/// Apply a [`tower::Layer`] to the handler.
///
@ -251,7 +251,7 @@ pub trait Handler<B, T>: Clone + Send + Sized + 'static {
/// errors. See [`Layered::handle_error`] for more details.
fn layer<L>(self, layer: L) -> Layered<L::Service, T>
where
L: Layer<OnMethod<Self, B, T, EmptyRouter>>,
L: Layer<OnMethod<Self, T, EmptyRouter>>,
{
Layered::new(layer.layer(any(self)))
}
@ -280,22 +280,21 @@ pub trait Handler<B, T>: Clone + Send + Sized + 'static {
/// # Ok::<_, hyper::Error>(())
/// # };
/// ```
fn into_service(self) -> IntoService<Self, B, T> {
fn into_service(self) -> IntoService<Self, T> {
IntoService::new(self)
}
}
#[async_trait]
impl<F, Fut, Res, B> Handler<B, ()> for F
impl<F, Fut, Res> Handler<()> for F
where
F: FnOnce() -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
B: Send + 'static,
{
type Sealed = sealed::Hidden;
async fn call(self, _req: Request<B>) -> Response<BoxBody> {
async fn call(self, _: ()) -> Response<BoxBody> {
self().await.into_response().map(box_body)
}
}
@ -307,35 +306,18 @@ macro_rules! impl_handler {
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, B, Res, $head, $($tail,)*> Handler<B, ($head, $($tail,)*)> for F
impl<F, Fut, Res, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
where
F: FnOnce($head, $($tail,)*) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Res> + Send,
B: Send + 'static,
Res: IntoResponse,
B: Send + 'static,
$head: FromRequest<B> + Send,
$( $tail: FromRequest<B> + Send,)*
$head: Send + 'static,
$( $tail: Send + 'static ),*
{
type Sealed = sealed::Hidden;
async fn call(self, req: Request<B>) -> Response<BoxBody> {
let mut req = crate::extract::RequestParts::new(req);
let $head = match $head::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response().map(box_body),
};
$(
let $tail = match $tail::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response().map(box_body),
};
)*
async fn call(self, ($head, $($tail,)*): ($head, $($tail,)*)) -> Response<BoxBody> {
let res = self($head, $($tail,)*).await;
res.into_response().map(crate::body::box_body)
}
}
@ -373,9 +355,9 @@ where
}
#[async_trait]
impl<S, T, ReqBody, ResBody> Handler<ReqBody, T> for Layered<S, T>
impl<S, T, ReqBody, ResBody> Handler<(Request<ReqBody>,)> for Layered<S, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
S::Error: IntoResponse,
S::Future: Send,
T: 'static,
@ -385,7 +367,7 @@ where
{
type Sealed = sealed::Hidden;
async fn call(self, req: Request<ReqBody>) -> Response<BoxBody> {
async fn call(self, (req,): (Request<ReqBody>,)) -> Response<BoxBody> {
match self
.svc
.oneshot(req)
@ -431,14 +413,14 @@ impl<S, T> Layered<S, T> {
/// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and
/// allows chaining additional handlers.
pub struct OnMethod<H, B, T, F> {
pub struct OnMethod<H, T, F> {
pub(crate) method: MethodFilter,
pub(crate) handler: H,
pub(crate) fallback: F,
pub(crate) _marker: PhantomData<fn() -> (B, T)>,
pub(crate) _marker: PhantomData<fn() -> T>,
}
impl<H, B, T, F> fmt::Debug for OnMethod<H, B, T, F>
impl<H, T, F> fmt::Debug for OnMethod<H, T, F>
where
T: fmt::Debug,
F: fmt::Debug,
@ -452,7 +434,7 @@ where
}
}
impl<H, B, T, F> Clone for OnMethod<H, B, T, F>
impl<H, T, F> Clone for OnMethod<H, T, F>
where
H: Clone,
F: Clone,
@ -467,21 +449,21 @@ where
}
}
impl<H, B, T, F> Copy for OnMethod<H, B, T, F>
impl<H, T, F> Copy for OnMethod<H, T, F>
where
H: Copy,
F: Copy,
{
}
impl<H, B, T, F> OnMethod<H, B, T, F> {
impl<H, T, F> OnMethod<H, T, F> {
/// Chain an additional handler that will accept all requests regardless of
/// its HTTP method.
///
/// See [`OnMethod::get`] for an example.
pub fn any<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn any<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::all(), handler)
}
@ -489,9 +471,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `CONNECT` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn connect<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn connect<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::CONNECT, handler)
}
@ -499,9 +481,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `DELETE` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn delete<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn delete<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::DELETE, handler)
}
@ -528,9 +510,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Note that `get` routes will also be called for `HEAD` requests but will have
/// the response body removed. Make sure to add explicit `HEAD` routes
/// afterwards.
pub fn get<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn get<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::GET | MethodFilter::HEAD, handler)
}
@ -538,9 +520,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `HEAD` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn head<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn head<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::HEAD, handler)
}
@ -548,9 +530,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `OPTIONS` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn options<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn options<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::OPTIONS, handler)
}
@ -558,9 +540,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `PATCH` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn patch<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn patch<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::PATCH, handler)
}
@ -568,9 +550,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `POST` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn post<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn post<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::POST, handler)
}
@ -578,9 +560,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `PUT` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn put<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn put<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::PUT, handler)
}
@ -588,9 +570,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// Chain an additional handler that will only accept `TRACE` requests.
///
/// See [`OnMethod::get`] for an example.
pub fn trace<H2, T2>(self, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn trace<H2, T2>(self, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
self.on(MethodFilter::TRACE, handler)
}
@ -618,9 +600,9 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn on<H2, T2>(self, method: MethodFilter, handler: H2) -> OnMethod<H2, B, T2, Self>
pub fn on<H2, T2>(self, method: MethodFilter, handler: H2) -> OnMethod<H2, T2, Self>
where
H2: Handler<B, T2>,
H2: Handler<T2>,
{
OnMethod {
method,
@ -631,10 +613,12 @@ impl<H, B, T, F> OnMethod<H, B, T, F> {
}
}
impl<H, B, T, F> Service<Request<B>> for OnMethod<H, B, T, F>
impl<H, T, B, F> Service<Request<B>> for OnMethod<H, T, F>
where
H: Handler<B, T>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
H: Handler<T>,
T: FromRequest<B> + Send,
T::Rejection: Send,
F: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone + Send,
B: Send + 'static,
{
type Response = Response<BoxBody>;
@ -649,7 +633,15 @@ where
let req_method = req.method().clone();
let fut = if self.method.matches(req.method()) {
let fut = Handler::call(self.handler.clone(), req);
let handler = self.handler.clone();
let fut = Box::pin(async move {
let mut req = RequestParts::new(req);
let input = T::from_request(&mut req).await;
match input {
Ok(input) => Handler::call(handler, input).await,
Err(rejection) => rejection.into_response().map(box_body),
}
}) as _;
Either::A { inner: fut }
} else {
let fut = self.fallback.clone().oneshot(req);

View file

@ -191,9 +191,9 @@ fn service_handle_on_router_still_impls_routing_dsl() {
#[test]
fn layered() {
let app = Router::new()
.route("/echo", get::<_, Body, _>(unit))
.route("/echo", get(unit))
.layer(timeout())
.handle_error(handle_error::<BoxError>);
.handle_error::<Body, _>(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
@ -201,9 +201,9 @@ fn layered() {
#[tokio::test] // async because of `.boxed()`
async fn layered_boxed() {
let app = Router::new()
.route("/echo", get::<_, Body, _>(unit))
.route("/echo", get(unit))
.layer(timeout())
.boxed()
.boxed::<Body, _>()
.handle_error(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());