diff --git a/Cargo.toml b/Cargo.toml index 3e4f5b5f..8056e6fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,8 +41,6 @@ tower-layer = "0.3" tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } sync_wrapper = "0.1.1" -dyn-clone = "1.0" - # optional dependencies tokio-tungstenite = { optional = true, version = "0.15" } sha-1 = { optional = true, version = "0.9.6" } diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 00000000..f76b8fd4 --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,194 @@ +use futures_util::ready; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; +use tower::ServiceExt; +use tower_service::Service; + +/// A version of [`tower::buffer::Buffer`] which panicks on channel related errors, thus keeping +/// the error type of the service. +pub(crate) struct MpscBuffer +where + S: Service, +{ + tx: mpsc::UnboundedSender>, + semaphore: PollSemaphore, + permit: Option, +} + +impl Clone for MpscBuffer +where + S: Service, +{ + fn clone(&self) -> Self { + Self { + tx: self.tx.clone(), + semaphore: self.semaphore.clone(), + permit: None, + } + } +} + +impl MpscBuffer +where + S: Service, +{ + pub(crate) fn new(svc: S) -> Self + where + S: Send + 'static, + R: Send + 'static, + S::Error: Send + 'static, + S::Future: Send + 'static, + { + let (tx, rx) = mpsc::unbounded_channel::>(); + let semaphore = PollSemaphore::new(Arc::new(Semaphore::new(1024))); + + tokio::spawn(run_worker(svc, rx)); + + Self { + tx, + semaphore, + permit: None, + } + } +} + +async fn run_worker(mut svc: S, mut rx: mpsc::UnboundedReceiver>) +where + S: Service, +{ + while let Some((req, reply_tx)) = rx.recv().await { + match svc.ready().await { + Ok(svc) => { + let future = svc.call(req); + let _ = reply_tx.send(WorkerReply::Future(future)); + } + Err(err) => { + let _ = reply_tx.send(WorkerReply::Error(err)); + } + } + } +} + +type Msg = ( + R, + oneshot::Sender>::Future, >::Error>>, +); + +enum WorkerReply { + Future(F), + Error(E), +} + +impl Service for MpscBuffer +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.permit.is_some() { + return Poll::Ready(Ok(())); + } + + let permit = ready!(self.semaphore.poll_acquire(cx)) + .expect("buffer semaphore closed. This is a bug in axum and should never happen. Please file an issue"); + + self.permit = Some(permit); + + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: R) -> Self::Future { + let permit = self + .permit + .take() + .expect("semaphore permit missing. Did you forget to call `poll_ready`?"); + + let (reply_tx, reply_rx) = oneshot::channel::>(); + + self.tx.send((req, reply_tx)).unwrap_or_else(|_| { + panic!("buffer worker not running. This is a bug in axum and should never happen. Please file an issue") + }); + + ResponseFuture { + state: State::Channel { reply_rx }, + permit, + } + } +} + +pin_project! { + pub(crate) struct ResponseFuture { + #[pin] + state: State, + permit: OwnedSemaphorePermit, + } +} + +pin_project! { + #[project = StateProj] + enum State { + Channel { reply_rx: oneshot::Receiver> }, + Future { #[pin] future: F }, + } +} + +impl Future for ResponseFuture +where + F: Future>, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + let new_state = match this.state.as_mut().project() { + StateProj::Channel { reply_rx } => { + let msg = ready!(Pin::new(reply_rx).poll(cx)) + .expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue"); + + match msg { + WorkerReply::Future(future) => State::Future { future }, + WorkerReply::Error(err) => return Poll::Ready(Err(err)), + } + } + StateProj::Future { future } => { + return future.poll(cx); + } + }; + + this.state.set(new_state); + } + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use tower::ServiceExt; + + #[tokio::test] + async fn test_buffer() { + let mut svc = MpscBuffer::new(tower::service_fn(handle)); + + let res = svc.ready().await.unwrap().call(42).await.unwrap(); + + assert_eq!(res, "foo"); + } + + async fn handle(req: i32) -> Result<&'static str, std::convert::Infallible> { + assert_eq!(req, 42); + Ok("foo") + } +} diff --git a/src/lib.rs b/src/lib.rs index 637e1100..082796f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1004,6 +1004,7 @@ #[macro_use] pub(crate) mod macros; +mod buffer; mod error; mod json; mod util; diff --git a/src/routing/future.rs b/src/routing/future.rs index bd08b8ea..6f6e9cfe 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -1,6 +1,6 @@ //! Future types. -use crate::{body::BoxBody, routing::FromEmptyRouter, util::CloneBoxService, BoxError}; +use crate::{body::BoxBody, buffer::MpscBuffer, routing::FromEmptyRouter, BoxError}; use futures_util::ready; use http::{Request, Response}; use pin_project_lite::pin_project; @@ -11,7 +11,10 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{util::Oneshot, ServiceExt}; +use tower::{ + util::{BoxService, Oneshot}, + ServiceExt, +}; use tower_service::Service; pub use super::or::ResponseFuture as OrResponseFuture; @@ -30,7 +33,10 @@ pin_project! { { #[pin] pub(super) inner: Oneshot< - CloneBoxService, Response, E>, + MpscBuffer< + BoxService, Response, E >, + Request + >, Request, >, } diff --git a/src/routing/mod.rs b/src/routing/mod.rs index ccf7597a..c8456618 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -3,12 +3,13 @@ use self::future::{BoxRouteFuture, EmptyRouterFuture, NestedFuture, RouteFuture}; use crate::{ body::{box_body, BoxBody}, + buffer::MpscBuffer, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, OriginalUri, }, service::HandleError, - util::{ByteStr, CloneBoxService}, + util::ByteStr, BoxError, }; use bytes::Bytes; @@ -23,7 +24,10 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower::{util::ServiceExt, ServiceBuilder}; +use tower::{ + util::{BoxService, ServiceExt}, + ServiceBuilder, +}; use tower_http::map_response_body::MapResponseBodyLayer; use tower_layer::Layer; use tower_service::Service; @@ -252,7 +256,7 @@ impl Router { /// routes. pub fn boxed(self) -> Router> where - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Send + 'static, S::Error: Into + Send, S::Future: Send, ReqBody: Send + 'static, @@ -261,8 +265,9 @@ impl Router { { self.map(|svc| { ServiceBuilder::new() - .layer_fn(|inner| BoxRoute { inner }) - .layer_fn(CloneBoxService::new) + .layer_fn(BoxRoute) + .layer_fn(MpscBuffer::new) + .layer(BoxService::layer()) .layer(MapResponseBodyLayer::new(box_body)) .service(svc) }) @@ -828,15 +833,13 @@ type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`Router::boxed`] for more details. -pub struct BoxRoute { - inner: CloneBoxService, Response, E>, -} +pub struct BoxRoute( + MpscBuffer, Response, E>, Request>, +); impl Clone for BoxRoute { fn clone(&self) -> Self { - BoxRoute { - inner: self.inner.clone(), - } + Self(self.0.clone()) } } @@ -862,7 +865,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { BoxRouteFuture { - inner: self.inner.clone().oneshot(req), + inner: self.0.clone().oneshot(req), } } } diff --git a/src/util/mod.rs b/src/util.rs similarity index 91% rename from src/util/mod.rs rename to src/util.rs index 2bce6833..a731f741 100644 --- a/src/util/mod.rs +++ b/src/util.rs @@ -2,10 +2,6 @@ use bytes::Bytes; use pin_project_lite::pin_project; use std::ops::Deref; -mod clone_box_service; - -pub(crate) use self::clone_box_service::CloneBoxService; - /// A string like type backed by `Bytes` making it cheap to clone. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct ByteStr(Bytes); diff --git a/src/util/clone_box_service.rs b/src/util/clone_box_service.rs deleted file mode 100644 index 1e70e3a5..00000000 --- a/src/util/clone_box_service.rs +++ /dev/null @@ -1,75 +0,0 @@ -use futures_util::future::BoxFuture; -use std::future::Future; -use std::task::{Context, Poll}; -use tower::ServiceExt; -use tower_service::Service; - -/// A boxed Service that implements Clone -/// -/// Could probably upstream this to tower -pub(crate) struct CloneBoxService { - inner: Box< - dyn CloneService>> - + Send, - >, -} - -impl CloneBoxService { - pub(crate) fn new(inner: S) -> Self - where - S: Service + Clone + Send + 'static, - S::Future: Send + 'static, - { - let inner = Box::new(inner.map_future(|f| Box::pin(f) as _)); - Self { inner } - } -} - -impl Clone for CloneBoxService { - fn clone(&self) -> Self { - Self { - inner: dyn_clone::clone_box(&*self.inner), - } - } -} - -impl Service for CloneBoxService { - type Response = U; - type Error = E; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - CloneService::poll_ready(&mut *self.inner, cx) - } - - fn call(&mut self, req: T) -> Self::Future { - CloneService::call(&mut *self.inner, req) - } -} - -trait CloneService: dyn_clone::DynClone { - type Response; - type Error; - type Future: Future>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; - - fn call(&mut self, req: R) -> Self::Future; -} - -impl CloneService for T -where - T: Service + Clone, -{ - type Response = T::Response; - type Error = T::Error; - type Future = T::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Service::poll_ready(self, cx) - } - - fn call(&mut self, req: R) -> Self::Future { - Service::call(self, req) - } -}