From 552d69e5d4af638c92a262d4d9e55334cb061226 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 26 Aug 2021 08:34:53 +0200 Subject: [PATCH] Remove buffer from `BoxRoute` (#270) Boxing a service normally means using `tower::util::BoxService`. That doesn't implement `Clone` however so normally I had been combining it with `Buffer` to get that. But recently I discovered https://github.com/dtolnay/dyn-clone which makes it possible to clone trait objects. So this adds a new internal utility called `CloneBoxService` which replaces the previous `BoxService` + `Buffer` combo in `BoxRoute`. I'll investigate upstreaming that to tower. I think it makes sense there since box + clone is quite a common need. --- Cargo.toml | 2 + src/buffer.rs | 194 ---------------------------------- src/lib.rs | 1 - src/routing/future.rs | 12 +-- src/routing/mod.rs | 27 +++-- src/util/clone_box_service.rs | 75 +++++++++++++ src/{util.rs => util/mod.rs} | 4 + 7 files changed, 96 insertions(+), 219 deletions(-) delete mode 100644 src/buffer.rs create mode 100644 src/util/clone_box_service.rs rename src/{util.rs => util/mod.rs} (91%) diff --git a/Cargo.toml b/Cargo.toml index 4147c244..46487291 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,8 @@ tower-layer = "0.3" tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } sync_wrapper = "0.1.1" +dyn-clone = "1.0" + # optional dependencies tokio-tungstenite = { optional = true, version = "0.15" } sha-1 = { optional = true, version = "0.9.6" } diff --git a/src/buffer.rs b/src/buffer.rs deleted file mode 100644 index f76b8fd4..00000000 --- a/src/buffer.rs +++ /dev/null @@ -1,194 +0,0 @@ -use futures_util::ready; -use pin_project_lite::pin_project; -use std::{ - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; -use tokio_util::sync::PollSemaphore; -use tower::ServiceExt; -use tower_service::Service; - -/// A version of [`tower::buffer::Buffer`] which panicks on channel related errors, thus keeping -/// the error type of the service. -pub(crate) struct MpscBuffer -where - S: Service, -{ - tx: mpsc::UnboundedSender>, - semaphore: PollSemaphore, - permit: Option, -} - -impl Clone for MpscBuffer -where - S: Service, -{ - fn clone(&self) -> Self { - Self { - tx: self.tx.clone(), - semaphore: self.semaphore.clone(), - permit: None, - } - } -} - -impl MpscBuffer -where - S: Service, -{ - pub(crate) fn new(svc: S) -> Self - where - S: Send + 'static, - R: Send + 'static, - S::Error: Send + 'static, - S::Future: Send + 'static, - { - let (tx, rx) = mpsc::unbounded_channel::>(); - let semaphore = PollSemaphore::new(Arc::new(Semaphore::new(1024))); - - tokio::spawn(run_worker(svc, rx)); - - Self { - tx, - semaphore, - permit: None, - } - } -} - -async fn run_worker(mut svc: S, mut rx: mpsc::UnboundedReceiver>) -where - S: Service, -{ - while let Some((req, reply_tx)) = rx.recv().await { - match svc.ready().await { - Ok(svc) => { - let future = svc.call(req); - let _ = reply_tx.send(WorkerReply::Future(future)); - } - Err(err) => { - let _ = reply_tx.send(WorkerReply::Error(err)); - } - } - } -} - -type Msg = ( - R, - oneshot::Sender>::Future, >::Error>>, -); - -enum WorkerReply { - Future(F), - Error(E), -} - -impl Service for MpscBuffer -where - S: Service, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.permit.is_some() { - return Poll::Ready(Ok(())); - } - - let permit = ready!(self.semaphore.poll_acquire(cx)) - .expect("buffer semaphore closed. This is a bug in axum and should never happen. Please file an issue"); - - self.permit = Some(permit); - - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: R) -> Self::Future { - let permit = self - .permit - .take() - .expect("semaphore permit missing. Did you forget to call `poll_ready`?"); - - let (reply_tx, reply_rx) = oneshot::channel::>(); - - self.tx.send((req, reply_tx)).unwrap_or_else(|_| { - panic!("buffer worker not running. This is a bug in axum and should never happen. Please file an issue") - }); - - ResponseFuture { - state: State::Channel { reply_rx }, - permit, - } - } -} - -pin_project! { - pub(crate) struct ResponseFuture { - #[pin] - state: State, - permit: OwnedSemaphorePermit, - } -} - -pin_project! { - #[project = StateProj] - enum State { - Channel { reply_rx: oneshot::Receiver> }, - Future { #[pin] future: F }, - } -} - -impl Future for ResponseFuture -where - F: Future>, -{ - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let mut this = self.as_mut().project(); - - let new_state = match this.state.as_mut().project() { - StateProj::Channel { reply_rx } => { - let msg = ready!(Pin::new(reply_rx).poll(cx)) - .expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue"); - - match msg { - WorkerReply::Future(future) => State::Future { future }, - WorkerReply::Error(err) => return Poll::Ready(Err(err)), - } - } - StateProj::Future { future } => { - return future.poll(cx); - } - }; - - this.state.set(new_state); - } - } -} - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - use tower::ServiceExt; - - #[tokio::test] - async fn test_buffer() { - let mut svc = MpscBuffer::new(tower::service_fn(handle)); - - let res = svc.ready().await.unwrap().call(42).await.unwrap(); - - assert_eq!(res, "foo"); - } - - async fn handle(req: i32) -> Result<&'static str, std::convert::Infallible> { - assert_eq!(req, 42); - Ok("foo") - } -} diff --git a/src/lib.rs b/src/lib.rs index 082796f5..637e1100 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1004,7 +1004,6 @@ #[macro_use] pub(crate) mod macros; -mod buffer; mod error; mod json; mod util; diff --git a/src/routing/future.rs b/src/routing/future.rs index 6f6e9cfe..bd08b8ea 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -1,6 +1,6 @@ //! Future types. -use crate::{body::BoxBody, buffer::MpscBuffer, routing::FromEmptyRouter, BoxError}; +use crate::{body::BoxBody, routing::FromEmptyRouter, util::CloneBoxService, BoxError}; use futures_util::ready; use http::{Request, Response}; use pin_project_lite::pin_project; @@ -11,10 +11,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{ - util::{BoxService, Oneshot}, - ServiceExt, -}; +use tower::{util::Oneshot, ServiceExt}; use tower_service::Service; pub use super::or::ResponseFuture as OrResponseFuture; @@ -33,10 +30,7 @@ pin_project! { { #[pin] pub(super) inner: Oneshot< - MpscBuffer< - BoxService, Response, E >, - Request - >, + CloneBoxService, Response, E>, Request, >, } diff --git a/src/routing/mod.rs b/src/routing/mod.rs index c8456618..ccf7597a 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -3,13 +3,12 @@ use self::future::{BoxRouteFuture, EmptyRouterFuture, NestedFuture, RouteFuture}; use crate::{ body::{box_body, BoxBody}, - buffer::MpscBuffer, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, OriginalUri, }, service::HandleError, - util::ByteStr, + util::{ByteStr, CloneBoxService}, BoxError, }; use bytes::Bytes; @@ -24,10 +23,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower::{ - util::{BoxService, ServiceExt}, - ServiceBuilder, -}; +use tower::{util::ServiceExt, ServiceBuilder}; use tower_http::map_response_body::MapResponseBodyLayer; use tower_layer::Layer; use tower_service::Service; @@ -256,7 +252,7 @@ impl Router { /// routes. pub fn boxed(self) -> Router> where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Error: Into + Send, S::Future: Send, ReqBody: Send + 'static, @@ -265,9 +261,8 @@ impl Router { { self.map(|svc| { ServiceBuilder::new() - .layer_fn(BoxRoute) - .layer_fn(MpscBuffer::new) - .layer(BoxService::layer()) + .layer_fn(|inner| BoxRoute { inner }) + .layer_fn(CloneBoxService::new) .layer(MapResponseBodyLayer::new(box_body)) .service(svc) }) @@ -833,13 +828,15 @@ type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`Router::boxed`] for more details. -pub struct BoxRoute( - MpscBuffer, Response, E>, Request>, -); +pub struct BoxRoute { + inner: CloneBoxService, Response, E>, +} impl Clone for BoxRoute { fn clone(&self) -> Self { - Self(self.0.clone()) + BoxRoute { + inner: self.inner.clone(), + } } } @@ -865,7 +862,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { BoxRouteFuture { - inner: self.0.clone().oneshot(req), + inner: self.inner.clone().oneshot(req), } } } diff --git a/src/util/clone_box_service.rs b/src/util/clone_box_service.rs new file mode 100644 index 00000000..1e70e3a5 --- /dev/null +++ b/src/util/clone_box_service.rs @@ -0,0 +1,75 @@ +use futures_util::future::BoxFuture; +use std::future::Future; +use std::task::{Context, Poll}; +use tower::ServiceExt; +use tower_service::Service; + +/// A boxed Service that implements Clone +/// +/// Could probably upstream this to tower +pub(crate) struct CloneBoxService { + inner: Box< + dyn CloneService>> + + Send, + >, +} + +impl CloneBoxService { + pub(crate) fn new(inner: S) -> Self + where + S: Service + Clone + Send + 'static, + S::Future: Send + 'static, + { + let inner = Box::new(inner.map_future(|f| Box::pin(f) as _)); + Self { inner } + } +} + +impl Clone for CloneBoxService { + fn clone(&self) -> Self { + Self { + inner: dyn_clone::clone_box(&*self.inner), + } + } +} + +impl Service for CloneBoxService { + type Response = U; + type Error = E; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + CloneService::poll_ready(&mut *self.inner, cx) + } + + fn call(&mut self, req: T) -> Self::Future { + CloneService::call(&mut *self.inner, req) + } +} + +trait CloneService: dyn_clone::DynClone { + type Response; + type Error; + type Future: Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + + fn call(&mut self, req: R) -> Self::Future; +} + +impl CloneService for T +where + T: Service + Clone, +{ + type Response = T::Response; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Service::poll_ready(self, cx) + } + + fn call(&mut self, req: R) -> Self::Future { + Service::call(self, req) + } +} diff --git a/src/util.rs b/src/util/mod.rs similarity index 91% rename from src/util.rs rename to src/util/mod.rs index a731f741..2bce6833 100644 --- a/src/util.rs +++ b/src/util/mod.rs @@ -2,6 +2,10 @@ use bytes::Bytes; use pin_project_lite::pin_project; use std::ops::Deref; +mod clone_box_service; + +pub(crate) use self::clone_box_service::CloneBoxService; + /// A string like type backed by `Bytes` making it cheap to clone. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct ByteStr(Bytes);