From 04d62798b6db7454815076c8f98782bb606ff5d7 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 12 Jun 2021 23:59:18 +0200 Subject: [PATCH] Reduce body boxing (#9) Previously, when routing between one or two requests the two body types would be merged by boxing them. This isn't ideal since it introduces a layer indirection for each route. We can't require the services to be routed between as not all services use the same body type. This changes that so it instead uses an `Either` enum that implements `http_body::Body` if each variant does. Will reduce the overall allocations and hopefully the compiler can optimize things if both variants are the same. --- Cargo.toml | 2 +- examples/key_value_store.rs | 3 +- src/body.rs | 109 ++++++++++++++++++++-- src/handler/mod.rs | 28 +++--- src/lib.rs | 15 +--- src/routing.rs | 174 ++++++++++++++++++------------------ src/service/mod.rs | 28 +++--- 7 files changed, 216 insertions(+), 143 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bab9ec77..497056e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ serde_json = "1.0" serde_urlencoded = "0.7" tokio = { version = "1", features = ["time"] } tower = { version = "0.4", features = ["util", "buffer", "make"] } -tower-http = { version = "0.1", features = ["add-extension"] } +tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } # optional dependencies tokio-tungstenite = { optional = true, version = "0.14" } diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 06d61f17..f74423a9 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -21,7 +21,6 @@ use tower_http::{ compression::CompressionLayer, trace::TraceLayer, }; use tower_web::{ - body::BoxBody, extract::{ContentLengthLimit, Extension, UrlParams}, prelude::*, response::IntoResponse, @@ -99,7 +98,7 @@ async fn list_keys(Extension(state): Extension) -> String { .join("\n") } -fn admin_routes() -> BoxRoute { +fn admin_routes() -> BoxRoute { async fn delete_all_keys(Extension(state): Extension) { state.write().unwrap().db.clear(); } diff --git a/src/body.rs b/src/body.rs index cb6f1f38..ca8acb69 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,7 +1,8 @@ //! HTTP body utilities. use bytes::Bytes; -use http_body::{Empty, Full}; +use http_body::{Empty, Full, SizeHint}; +use pin_project::pin_project; use std::{ error::Error as StdError, fmt, @@ -35,12 +36,6 @@ impl BoxBody { } } -impl Default for BoxBody { - fn default() -> Self { - BoxBody::new(Empty::::new()) - } -} - impl fmt::Debug for BoxBody { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxBody").finish() @@ -101,3 +96,103 @@ impl fmt::Display for BoxStdError { self.0.fmt(f) } } + +/// Type that combines two body types into one. +#[pin_project] +#[derive(Debug)] +pub struct Or(#[pin] Either); + +impl Or { + #[inline] + pub(crate) fn a(a: A) -> Self { + Or(Either::A(a)) + } + + #[inline] + pub(crate) fn b(b: B) -> Self { + Or(Either::B(b)) + } +} + +impl Default for Or { + fn default() -> Self { + Self(Either::Empty(Empty::new())) + } +} + +#[pin_project(project = EitherProj)] +#[derive(Debug)] +enum Either { + Empty(Empty), // required for `Default` + A(#[pin] A), + B(#[pin] B), +} + +impl http_body::Body for Or +where + A: http_body::Body, + A::Error: Into, + B: http_body::Body, + B::Error: Into, +{ + type Data = Bytes; + type Error = BoxStdError; + + #[inline] + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project().0.project() { + EitherProj::Empty(inner) => Pin::new(inner).poll_data(cx).map(map_option_error), + EitherProj::A(inner) => inner.poll_data(cx).map(map_option_error), + EitherProj::B(inner) => inner.poll_data(cx).map(map_option_error), + } + } + + #[inline] + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + match self.project().0.project() { + EitherProj::Empty(inner) => Pin::new(inner) + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + EitherProj::A(inner) => inner + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + EitherProj::B(inner) => inner + .poll_trailers(cx) + .map_err(Into::into) + .map_err(BoxStdError), + } + } + + #[inline] + fn size_hint(&self) -> SizeHint { + match &self.0 { + Either::Empty(inner) => inner.size_hint(), + Either::A(inner) => inner.size_hint(), + Either::B(inner) => inner.size_hint(), + } + } + + #[inline] + fn is_end_stream(&self) -> bool { + match &self.0 { + Either::Empty(inner) => inner.is_end_stream(), + Either::A(inner) => inner.is_end_stream(), + Either::B(inner) => inner.is_end_stream(), + } + } +} + +fn map_option_error(opt: Option>) -> Option> +where + E: Into, +{ + opt.map(|result| result.map_err(Into::::into).map_err(BoxStdError)) +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 48c8d9a1..c48a2180 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -39,15 +39,14 @@ //! the [`extract`](crate::extract) module. use crate::{ - body::{Body, BoxBody}, + body::{self, Body, BoxBody}, extract::FromRequest, response::IntoResponse, - routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, + routing::{EmptyRouter, MethodFilter, RouteFuture}, service::HandleError, }; use async_trait::async_trait; use bytes::Bytes; -use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, @@ -647,14 +646,14 @@ impl OnMethod { impl Service> for OnMethod where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -663,13 +662,12 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let f = if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - Either::Left(BoxResponseBody(response_future)) + if self.method.matches(req.method()) { + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } diff --git a/src/lib.rs b/src/lib.rs index 0fe2313a..8eb8e471 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -510,7 +510,7 @@ //! use http::Response; //! use std::convert::Infallible; //! -//! fn api_routes() -> BoxRoute { +//! fn api_routes() -> BoxRoute { //! route("/users", get(|_: Request| async { /* ... */ })).boxed() //! } //! @@ -692,19 +692,6 @@ where routing::EmptyRouter.route(description, service) } -pub(crate) trait ResultExt { - fn unwrap_infallible(self) -> T; -} - -impl ResultExt for Result { - fn unwrap_infallible(self) -> T { - match self { - Ok(value) => value, - Err(err) => match err {}, - } - } -} - mod sealed { #![allow(unreachable_pub, missing_docs)] diff --git a/src/routing.rs b/src/routing.rs index 82bb9fe7..e189f839 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,6 +1,9 @@ //! Routing between [`Service`]s. -use crate::{body::BoxBody, response::IntoResponse, ResultExt}; +use crate::{ + body::{self, BoxBody}, + response::IntoResponse, +}; use async_trait::async_trait; use bytes::Bytes; use futures_util::{future, ready}; @@ -23,6 +26,7 @@ use tower::{ util::{BoxService, Oneshot, ServiceExt}, BoxError, Layer, Service, ServiceBuilder, }; +use tower_http::map_response_body::MapResponseBodyLayer; /// A filter that matches one or more HTTP method. #[derive(Debug, Copy, Clone)] @@ -132,7 +136,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// return them from functions: /// /// ```rust - /// use tower_web::{body::BoxBody, routing::BoxRoute, prelude::*}; + /// use tower_web::{routing::BoxRoute, prelude::*}; /// /// async fn first_handler() { /* ... */ } /// @@ -140,7 +144,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// async fn third_handler() { /* ... */ } /// - /// fn app() -> BoxRoute { + /// fn app() -> BoxRoute { /// route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)) /// .boxed() @@ -149,7 +153,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// /// It also helps with compile times when you have a very large number of /// routes. - fn boxed(self) -> BoxRoute + fn boxed(self) -> BoxRoute where Self: Service, Response = Response, Error = Infallible> + Send + 'static, >>::Future: Send, @@ -160,6 +164,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { .layer_fn(BoxRoute) .buffer(1024) .layer(BoxService::layer()) + .layer(MapResponseBodyLayer::new(BoxBody::new)) .service(self) } @@ -292,14 +297,14 @@ impl crate::sealed::Sealed for Route {} impl Service> for Route where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -308,45 +313,71 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let f = if let Some(captures) = self.pattern.full_match(req.uri().path()) { + if let Some(captures) = self.pattern.full_match(req.uri().path()) { insert_url_params(&mut req, captures); - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } /// The response future for [`Route`]. #[pin_project] #[derive(Debug)] -pub struct RouteFuture( - #[pin] - pub(crate) future::Either< - BoxResponseBody>>, - BoxResponseBody>>, - >, -) +pub struct RouteFuture(#[pin] RouteFutureInner) where S: Service>, F: Service>; +impl RouteFuture +where + S: Service>, + F: Service>, +{ + pub(crate) fn a(a: Oneshot>) -> Self { + RouteFuture(RouteFutureInner::A(a)) + } + + pub(crate) fn b(b: Oneshot>) -> Self { + RouteFuture(RouteFutureInner::B(b)) + } +} + +#[pin_project(project = RouteFutureInnerProj)] +#[derive(Debug)] +enum RouteFutureInner +where + S: Service>, + F: Service>, +{ + A(#[pin] Oneshot>), + B(#[pin] Oneshot>), +} + impl Future for RouteFuture where S: Service, Response = Response, Error = Infallible>, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, F: Service, Response = Response, Error = Infallible>, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Output = Result, Infallible>; + type Output = Result>, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().0.poll(cx) + match self.project().0.project() { + RouteFutureInnerProj::A(inner) => inner + .poll(cx) + .map(|result| result.map(|res| res.map(body::Or::a))), + RouteFutureInnerProj::B(inner) => inner + .poll(cx) + .map(|result| result.map(|res| res.map(body::Or::b))), + } } } @@ -363,29 +394,6 @@ fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { } } -/// A response future that boxes the response body with [`BoxBody`]. -#[pin_project] -#[derive(Debug)] -pub struct BoxResponseBody(#[pin] pub(crate) F); - -impl Future for BoxResponseBody -where - F: Future, Infallible>>, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into, -{ - type Output = Result, Infallible>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response: Response = ready!(self.project().0.poll(cx)).unwrap_infallible(); - let response = response.map(|body| { - let body = body.map_err(Into::into); - BoxBody::new(body) - }); - Poll::Ready(Ok(response)) - } -} - /// A [`Service`] that responds with `404 Not Found` to all requests. /// /// This is used as the bottom service in a router stack. You shouldn't have to @@ -513,32 +521,25 @@ type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`RoutingDsl::boxed`] for more details. -pub struct BoxRoute(Buffer, Response, Infallible>, Request>); +#[derive(Clone)] +pub struct BoxRoute( + Buffer, Response, Infallible>, Request>, +); -impl fmt::Debug for BoxRoute { +impl fmt::Debug for BoxRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRoute").finish() } } -impl Clone for BoxRoute { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} +impl RoutingDsl for BoxRoute {} -impl RoutingDsl for BoxRoute {} +impl crate::sealed::Sealed for BoxRoute {} -impl crate::sealed::Sealed for BoxRoute {} - -impl Service> for BoxRoute -where - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ +impl Service> for BoxRoute { type Response = Response; type Error = Infallible; - type Future = BoxRouteFuture; + type Future = BoxRouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -553,29 +554,25 @@ where /// The response future for [`BoxRoute`]. #[pin_project] -pub struct BoxRouteFuture(#[pin] InnerFuture); +pub struct BoxRouteFuture(#[pin] InnerFuture); -type InnerFuture = Oneshot< - Buffer, Response, Infallible>, Request>, +type InnerFuture = Oneshot< + Buffer, Response, Infallible>, Request>, Request, >; -impl fmt::Debug for BoxRouteFuture { +impl fmt::Debug for BoxRouteFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRouteFuture").finish() } } -impl Future for BoxRouteFuture -where - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ +impl Future for BoxRouteFuture { type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match ready!(self.project().0.poll(cx)) { - Ok(res) => Poll::Ready(Ok(res.map(BoxBody::new))), + Ok(res) => Poll::Ready(Ok(res)), Err(err) => Poll::Ready(Ok(handle_buffer_error(err))), } } @@ -792,14 +789,14 @@ impl crate::sealed::Sealed for Nested {} impl Service> for Nested where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -808,18 +805,17 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { + if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(req.uri(), prefix); *req.uri_mut() = without_prefix; insert_url_params(&mut req, captures); - let response_future = self.svc.clone().oneshot(req); - future::Either::Left(BoxResponseBody(response_future)) + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - future::Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 9054a22e..33f3a3bf 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -84,12 +84,11 @@ //! [load shed]: tower::load_shed use crate::{ - body::{Body, BoxBody}, + body::{self, Body, BoxBody}, response::IntoResponse, - routing::{BoxResponseBody, EmptyRouter, MethodFilter, RouteFuture}, + routing::{EmptyRouter, MethodFilter, RouteFuture}, }; use bytes::Bytes; -use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, @@ -408,14 +407,14 @@ impl OnMethod { impl Service> for OnMethod where S: Service, Response = Response, Error = Infallible> + Clone, - SB: http_body::Body + Send + Sync + 'static, - SB::Error: Into, - F: Service, Response = Response, Error = Infallible> + Clone, - FB: http_body::Body + Send + Sync + 'static, + + SB: http_body::Body, + SB::Error: Into, + FB: http_body::Body, FB::Error: Into, { - type Response = Response; + type Response = Response>; type Error = Infallible; type Future = RouteFuture; @@ -424,14 +423,13 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let f = if self.method.matches(req.method()) { - let response_future = self.svc.clone().oneshot(req); - Either::Left(BoxResponseBody(response_future)) + if self.method.matches(req.method()) { + let fut = self.svc.clone().oneshot(req); + RouteFuture::a(fut) } else { - let response_future = self.fallback.clone().oneshot(req); - Either::Right(BoxResponseBody(response_future)) - }; - RouteFuture(f) + let fut = self.fallback.clone().oneshot(req); + RouteFuture::b(fut) + } } }