diff --git a/Cargo.toml b/Cargo.toml index bab9ec77..497056e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ serde_json = "1.0" serde_urlencoded = "0.7" tokio = { version = "1", features = ["time"] } tower = { version = "0.4", features = ["util", "buffer", "make"] } -tower-http = { version = "0.1", features = ["add-extension"] } +tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } # optional dependencies tokio-tungstenite = { optional = true, version = "0.14" } diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 06d61f17..f74423a9 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -21,7 +21,6 @@ use tower_http::{ compression::CompressionLayer, trace::TraceLayer, }; use tower_web::{ - body::BoxBody, extract::{ContentLengthLimit, Extension, UrlParams}, prelude::*, response::IntoResponse, @@ -99,7 +98,7 @@ async fn list_keys(Extension(state): Extension) -> String { .join("\n") } -fn admin_routes() -> BoxRoute { +fn admin_routes() -> BoxRoute { async fn delete_all_keys(Extension(state): Extension) { state.write().unwrap().db.clear(); } diff --git a/src/body.rs b/src/body.rs index cb6f1f38..ca8acb69 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,7 +1,8 @@ //! HTTP body utilities. use bytes::Bytes; -use http_body::{Empty, Full}; +use http_body::{Empty, Full, SizeHint}; +use pin_project::pin_project; use std::{ error::Error as StdError, fmt, @@ -35,12 +36,6 @@ impl BoxBody { } } -impl Default for BoxBody { - fn default() -> Self { - BoxBody::new(Empty::::new()) - } -} - impl fmt::Debug for BoxBody { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxBody").finish() @@ -101,3 +96,103 @@ impl fmt::Display for BoxStdError { self.0.fmt(f) } } + +/// Type that combines two body types into one. +#[pin_project] +#[derive(Debug)] +pub struct Or(#[pin] Either); + +impl Or { + #[inline] + pub(crate) fn a(a: A) -> Self { + Or(Either::A(a)) + } + + #[inline] + pub(crate) fn b(b: B) -> Self { + Or(Either::B(b)) + } +} + +impl Default for Or { + fn default() -> Self { + Self(Either::Empty(Empty::new())) + } +} + +#[pin_project(project = EitherProj)] +#[derive(Debug)] +enum Either { + Empty(Empty), // required for `Default` + A(#[pin] A), + B(#[pin] B), +} + +impl http_body::Body for Or +where + A: http_body::Body, + A::Error: Into, + B: http_body::Body, + B::Error: Into, +{ + type Data = Bytes; + type Error = BoxStdError; + + #[inline] + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project().0.project() { + EitherProj::Empty(inner) => Pin::new(inner).poll_data(cx).map(map_option_error), + EitherProj::A(inner) => inner.poll_data(cx).map(map_option_error), + EitherProj::B(inner) => inner.poll_data(cx).map(map_option_error), + } + } + + #[inline] + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + match self.project().0.project() { + EitherProj::Empty(inner) => Pin::new(inner) + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + EitherProj::A(inner) => inner + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + EitherProj::B(inner) => inner + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + } + } + + #[inline] + fn size_hint(&self) -> SizeHint { + match &self.0 { + Either::Empty(inner) => inner.size_hint(), + Either::A(inner) => inner.size_hint(), + Either::B(inner) => inner.size_hint(), + } + } + + #[inline] + fn is_end_stream(&self) -> bool { + match &self.0 { + Either::Empty(inner) => inner.is_end_stream(), + Either::A(inner) => inner.is_end_stream(), + Either::B(inner) => inner.is_end_stream(), + } + } +} + +fn map_option_error(opt: Option>) -> Option> +where + E: Into, +{ + opt.map(|result| result.map_err(Into::::into).map_err(BoxStdError)) +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 48c8d9a1..c48a2180 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -39,15 +39,14 @@ //! the [`extract`](crate::extract) module. use crate::{ - body::{Body, BoxBody}, + body::{self, Body, BoxBody}, extract::FromRequest, response::IntoResponse, - routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, + routing::{EmptyRouter, MethodFilter, RouteFuture}, service::HandleError, }; use async_trait::async_trait; use bytes::Bytes; -use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, @@ -647,14 +646,14 @@ impl OnMethod { impl Service> for OnMethod where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -663,13 +662,12 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let f = if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - Either::Left(BoxResponseBody(response_future)) + if self.method.matches(req.method()) { + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } diff --git a/src/lib.rs b/src/lib.rs index 0fe2313a..8eb8e471 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -510,7 +510,7 @@ //! use http::Response; //! use std::convert::Infallible; //! -//! fn api_routes() -> BoxRoute { +//! fn api_routes() -> BoxRoute { //! route("/users", get(|_: Request| async { /* ... */ })).boxed() //! } //! @@ -692,19 +692,6 @@ where routing::EmptyRouter.route(description, service) } -pub(crate) trait ResultExt { - fn unwrap_infallible(self) -> T; -} - -impl ResultExt for Result { - fn unwrap_infallible(self) -> T { - match self { - Ok(value) => value, - Err(err) => match err {}, - } - } -} - mod sealed { #![allow(unreachable_pub, missing_docs)] diff --git a/src/routing.rs b/src/routing.rs index 82bb9fe7..e189f839 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,6 +1,9 @@ //! Routing between [`Service`]s. -use crate::{body::BoxBody, response::IntoResponse, ResultExt}; +use crate::{ + body::{self, BoxBody}, + response::IntoResponse, +}; use async_trait::async_trait; use bytes::Bytes; use futures_util::{future, ready}; @@ -23,6 +26,7 @@ use tower::{ util::{BoxService, Oneshot, ServiceExt}, BoxError, Layer, Service, ServiceBuilder, }; +use tower_http::map_response_body::MapResponseBodyLayer; /// A filter that matches one or more HTTP method. #[derive(Debug, Copy, Clone)] @@ -132,7 +136,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// return them from functions: /// /// ```rust - /// use tower_web::{body::BoxBody, routing::BoxRoute, prelude::*}; + /// use tower_web::{routing::BoxRoute, prelude::*}; /// /// async fn first_handler() { /* ... */ } /// @@ -140,7 +144,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// async fn third_handler() { /* ... */ } /// - /// fn app() -> BoxRoute { + /// fn app() -> BoxRoute { /// route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)) /// .boxed() @@ -149,7 +153,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// It also helps with compile times when you have a very large number of /// routes. - fn boxed(self) -> BoxRoute + fn boxed(self) -> BoxRoute where Self: Service, Response = Response, Error = Infallible> + Send + 'static, >>::Future: Send, @@ -160,6 +164,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { .layer_fn(BoxRoute) .buffer(1024) .layer(BoxService::layer()) + .layer(MapResponseBodyLayer::new(BoxBody::new)) .service(self) } @@ -292,14 +297,14 @@ impl crate::sealed::Sealed for Route {} impl Service> for Route where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -308,45 +313,71 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let f = if let Some(captures) = self.pattern.full_match(req.uri().path()) { + if let Some(captures) = self.pattern.full_match(req.uri().path()) { insert_url_params(&mut req, captures); - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } /// The response future for [`Route`]. #[pin_project] #[derive(Debug)] -pub struct RouteFuture( - #[pin] - pub(crate) future::Either< - BoxResponseBody>>, - BoxResponseBody>>, - >, -) +pub struct RouteFuture(#[pin] RouteFutureInner) where S: Service>, F: Service>; +impl RouteFuture +where + S: Service>, + F: Service>, +{ + pub(crate) fn a(a: Oneshot>) -> Self { + RouteFuture(RouteFutureInner::A(a)) + } + + pub(crate) fn b(b: Oneshot>) -> Self { + RouteFuture(RouteFutureInner::B(b)) + } +} + +#[pin_project(project = RouteFutureInnerProj)] +#[derive(Debug)] +enum RouteFutureInner +where + S: Service>, + F: Service>, +{ + A(#[pin] Oneshot>), + B(#[pin] Oneshot>), +} + impl Future for RouteFuture where S: Service, Response = Response, Error = Infallible>, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, F: Service, Response = Response, Error = Infallible>, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Output = Result, Infallible>; + type Output = Result>, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().0.poll(cx) + match self.project().0.project() { + RouteFutureInnerProj::A(inner) => inner + .poll(cx) + .map(|result| result.map(|res| res.map(body::Or::a))), + RouteFutureInnerProj::B(inner) => inner + .poll(cx) + .map(|result| result.map(|res| res.map(body::Or::b))), + } } } @@ -363,29 +394,6 @@ fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { } } -/// A response future that boxes the response body with [`BoxBody`]. -#[pin_project] -#[derive(Debug)] -pub struct BoxResponseBody(#[pin] pub(crate) F); - -impl Future for BoxResponseBody -where - F: Future, Infallible>>, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into, -{ - type Output = Result, Infallible>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response: Response = ready!(self.project().0.poll(cx)).unwrap_infallible(); - let response = response.map(|body| { - let body = body.map_err(Into::into); - BoxBody::new(body) - }); - Poll::Ready(Ok(response)) - } -} - /// A [`Service`] that responds with `404 Not Found` to all requests. /// /// This is used as the bottom service in a router stack. You shouldn't have to @@ -513,32 +521,25 @@ type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`RoutingDsl::boxed`] for more details. -pub struct BoxRoute(Buffer, Response, Infallible>, Request>); +#[derive(Clone)] +pub struct BoxRoute( + Buffer, Response, Infallible>, Request>, +); -impl fmt::Debug for BoxRoute { +impl fmt::Debug for BoxRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRoute").finish() } } -impl Clone for BoxRoute { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} +impl RoutingDsl for BoxRoute {} -impl RoutingDsl for BoxRoute {} +impl crate::sealed::Sealed for BoxRoute {} -impl crate::sealed::Sealed for BoxRoute {} - -impl Service> for BoxRoute -where - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ +impl Service> for BoxRoute { type Response = Response; type Error = Infallible; - type Future = BoxRouteFuture; + type Future = BoxRouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -553,29 +554,25 @@ where /// The response future for [`BoxRoute`]. #[pin_project] -pub struct BoxRouteFuture(#[pin] InnerFuture); +pub struct BoxRouteFuture(#[pin] InnerFuture); -type InnerFuture = Oneshot< - Buffer, Response, Infallible>, Request>, +type InnerFuture = Oneshot< + Buffer, Response, Infallible>, Request>, Request, >; -impl fmt::Debug for BoxRouteFuture { +impl fmt::Debug for BoxRouteFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRouteFuture").finish() } } -impl Future for BoxRouteFuture -where - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ +impl Future for BoxRouteFuture { type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match ready!(self.project().0.poll(cx)) { - Ok(res) => Poll::Ready(Ok(res.map(BoxBody::new))), + Ok(res) => Poll::Ready(Ok(res)), Err(err) => Poll::Ready(Ok(handle_buffer_error(err))), } } @@ -792,14 +789,14 @@ impl crate::sealed::Sealed for Nested {} impl Service> for Nested where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -808,18 +805,17 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { + if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(req.uri(), prefix); *req.uri_mut() = without_prefix; insert_url_params(&mut req, captures); - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 9054a22e..33f3a3bf 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -84,12 +84,11 @@ //! [load shed]: tower::load_shed use crate::{ - body::{Body, BoxBody}, + body::{self, Body, BoxBody}, response::IntoResponse, - routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, + routing::{EmptyRouter, MethodFilter, RouteFuture}, }; use bytes::Bytes; -use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, @@ -408,14 +407,14 @@ impl OnMethod { impl Service> for OnMethod where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -424,14 +423,13 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let f = if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - Either::Left(BoxResponseBody(response_future)) + if self.method.matches(req.method()) { + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } }