diff --git a/CHANGELOG.md b/CHANGELOG.md index 43804059..e97e5d5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None +- Improve performance of `BoxRoute` ([#339]) +- **breaking:** `Router::boxed` now the inner service to implement `Clone` and + `Sync` in addition to the previous trait bounds ([#339]) + +[#339]: https://github.com/tokio-rs/axum/pull/339 # 0.2.6 (02. October, 2021) diff --git a/src/buffer.rs b/src/buffer.rs deleted file mode 100644 index f76b8fd4..00000000 --- a/src/buffer.rs +++ /dev/null @@ -1,194 +0,0 @@ -use futures_util::ready; -use pin_project_lite::pin_project; -use std::{ - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore}; -use tokio_util::sync::PollSemaphore; -use tower::ServiceExt; -use tower_service::Service; - -/// A version of [`tower::buffer::Buffer`] which panicks on channel related errors, thus keeping -/// the error type of the service. -pub(crate) struct MpscBuffer<S, R> -where - S: Service<R>, -{ - tx: mpsc::UnboundedSender<Msg<S, R>>, - semaphore: PollSemaphore, - permit: Option<OwnedSemaphorePermit>, -} - -impl<S, R> Clone for MpscBuffer<S, R> -where - S: Service<R>, -{ - fn clone(&self) -> Self { - Self { - tx: self.tx.clone(), - semaphore: self.semaphore.clone(), - permit: None, - } - } -} - -impl<S, R> MpscBuffer<S, R> -where - S: Service<R>, -{ - pub(crate) fn new(svc: S) -> Self - where - S: Send + 'static, - R: Send + 'static, - S::Error: Send + 'static, - S::Future: Send + 'static, - { - let (tx, rx) = mpsc::unbounded_channel::<Msg<S, R>>(); - let semaphore = PollSemaphore::new(Arc::new(Semaphore::new(1024))); - - tokio::spawn(run_worker(svc, rx)); - - Self { - tx, - semaphore, - permit: None, - } - } -} - -async fn run_worker<S, R>(mut svc: S, mut rx: mpsc::UnboundedReceiver<Msg<S, R>>) -where - S: Service<R>, -{ - while let Some((req, reply_tx)) = rx.recv().await { - match svc.ready().await { - Ok(svc) => { - let future = svc.call(req); - let _ = reply_tx.send(WorkerReply::Future(future)); - } - Err(err) => { - let _ = reply_tx.send(WorkerReply::Error(err)); - } - } - } -} - -type Msg<S, R> = ( - R, - oneshot::Sender<WorkerReply<<S as Service<R>>::Future, <S as Service<R>>::Error>>, -); - -enum WorkerReply<F, E> { - Future(F), - Error(E), -} - -impl<S, R> Service<R> for MpscBuffer<S, R> -where - S: Service<R>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture<S::Future, S::Error>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - if self.permit.is_some() { - return Poll::Ready(Ok(())); - } - - let permit = ready!(self.semaphore.poll_acquire(cx)) - .expect("buffer semaphore closed. This is a bug in axum and should never happen. Please file an issue"); - - self.permit = Some(permit); - - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: R) -> Self::Future { - let permit = self - .permit - .take() - .expect("semaphore permit missing. Did you forget to call `poll_ready`?"); - - let (reply_tx, reply_rx) = oneshot::channel::<WorkerReply<S::Future, S::Error>>(); - - self.tx.send((req, reply_tx)).unwrap_or_else(|_| { - panic!("buffer worker not running. This is a bug in axum and should never happen. Please file an issue") - }); - - ResponseFuture { - state: State::Channel { reply_rx }, - permit, - } - } -} - -pin_project! { - pub(crate) struct ResponseFuture<F, E> { - #[pin] - state: State<F, E>, - permit: OwnedSemaphorePermit, - } -} - -pin_project! { - #[project = StateProj] - enum State<F, E> { - Channel { reply_rx: oneshot::Receiver<WorkerReply<F, E>> }, - Future { #[pin] future: F }, - } -} - -impl<F, E, T> Future for ResponseFuture<F, E> -where - F: Future<Output = Result<T, E>>, -{ - type Output = Result<T, E>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - loop { - let mut this = self.as_mut().project(); - - let new_state = match this.state.as_mut().project() { - StateProj::Channel { reply_rx } => { - let msg = ready!(Pin::new(reply_rx).poll(cx)) - .expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue"); - - match msg { - WorkerReply::Future(future) => State::Future { future }, - WorkerReply::Error(err) => return Poll::Ready(Err(err)), - } - } - StateProj::Future { future } => { - return future.poll(cx); - } - }; - - this.state.set(new_state); - } - } -} - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - use tower::ServiceExt; - - #[tokio::test] - async fn test_buffer() { - let mut svc = MpscBuffer::new(tower::service_fn(handle)); - - let res = svc.ready().await.unwrap().call(42).await.unwrap(); - - assert_eq!(res, "foo"); - } - - async fn handle(req: i32) -> Result<&'static str, std::convert::Infallible> { - assert_eq!(req, 42); - Ok("foo") - } -} diff --git a/src/clone_box_service.rs b/src/clone_box_service.rs new file mode 100644 index 00000000..3642d4b1 --- /dev/null +++ b/src/clone_box_service.rs @@ -0,0 +1,69 @@ +use futures_util::future::BoxFuture; +use std::task::{Context, Poll}; +use tower::ServiceExt; +use tower_service::Service; + +/// A `Clone + Send + Sync` boxed `Service` +pub(crate) struct CloneBoxService<T, U, E>( + Box< + dyn CloneService<T, Response = U, Error = E, Future = BoxFuture<'static, Result<U, E>>> + + Send + + Sync, + >, +); + +impl<T, U, E> CloneBoxService<T, U, E> { + pub(crate) fn new<S>(inner: S) -> Self + where + S: Service<T, Response = U, Error = E> + Clone + Send + Sync + 'static, + S::Future: Send + 'static, + { + let inner = inner.map_future(|f| Box::pin(f) as _); + CloneBoxService(Box::new(inner)) + } +} + +impl<T, U, E> Service<T> for CloneBoxService<T, U, E> { + type Response = U; + type Error = E; + type Future = BoxFuture<'static, Result<U, E>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> { + self.0.poll_ready(cx) + } + + fn call(&mut self, request: T) -> Self::Future { + self.0.call(request) + } +} + +impl<T, U, E> Clone for CloneBoxService<T, U, E> { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +trait CloneService<R>: Service<R> { + fn clone_box( + &self, + ) -> Box< + dyn CloneService<R, Response = Self::Response, Error = Self::Error, Future = Self::Future> + + Send + + Sync, + >; +} + +impl<R, T> CloneService<R> for T +where + T: Service<R> + Send + Sync + Clone + 'static, +{ + fn clone_box( + &self, + ) -> Box< + dyn CloneService<R, Response = T::Response, Error = T::Error, Future = T::Future> + + Send + + Sync, + > { + Box::new(self.clone()) + } +} diff --git a/src/lib.rs b/src/lib.rs index ba102a0a..6c212e7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1191,7 +1191,7 @@ #[macro_use] pub(crate) mod macros; -mod buffer; +mod clone_box_service; mod error; mod json; mod util; diff --git a/src/routing/future.rs b/src/routing/future.rs index 6f6e9cfe..3296d344 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -1,6 +1,8 @@ //! Future types. -use crate::{body::BoxBody, buffer::MpscBuffer, routing::FromEmptyRouter, BoxError}; +use crate::{ + body::BoxBody, clone_box_service::CloneBoxService, routing::FromEmptyRouter, BoxError, +}; use futures_util::ready; use http::{Request, Response}; use pin_project_lite::pin_project; @@ -11,10 +13,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{ - util::{BoxService, Oneshot}, - ServiceExt, -}; +use tower::{util::Oneshot, ServiceExt}; use tower_service::Service; pub use super::or::ResponseFuture as OrResponseFuture; @@ -33,10 +32,7 @@ pin_project! { { #[pin] pub(super) inner: Oneshot< - MpscBuffer< - BoxService<Request<B>, Response<BoxBody>, E >, - Request<B> - >, + CloneBoxService<Request<B>, Response<BoxBody>, E>, Request<B>, >, } diff --git a/src/routing/mod.rs b/src/routing/mod.rs index dfb9c523..8f2fab37 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -3,7 +3,7 @@ use self::future::{BoxRouteFuture, EmptyRouterFuture, NestedFuture, RouteFuture}; use crate::{ body::{box_body, BoxBody}, - buffer::MpscBuffer, + clone_box_service::CloneBoxService, extract::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, OriginalUri, @@ -24,10 +24,7 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower::{ - util::{BoxService, ServiceExt}, - ServiceBuilder, -}; +use tower::{util::ServiceExt, ServiceBuilder}; use tower_http::map_response_body::MapResponseBodyLayer; use tower_layer::Layer; use tower_service::Service; @@ -256,7 +253,7 @@ impl<S> Router<S> { /// routes. pub fn boxed<ReqBody, ResBody>(self) -> Router<BoxRoute<ReqBody, S::Error>> where - S: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + 'static, + S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + Sync + 'static, S::Error: Into<BoxError> + Send, S::Future: Send, ReqBody: Send + 'static, @@ -266,8 +263,7 @@ impl<S> Router<S> { self.map(|svc| { ServiceBuilder::new() .layer_fn(BoxRoute) - .layer_fn(MpscBuffer::new) - .layer(BoxService::layer()) + .layer_fn(CloneBoxService::new) .layer(MapResponseBodyLayer::new(box_body)) .service(svc) }) @@ -834,7 +830,7 @@ type Captures = Vec<(String, String)>; /// /// See [`Router::boxed`] for more details. pub struct BoxRoute<B = crate::body::Body, E = Infallible>( - MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, + CloneBoxService<Request<B>, Response<BoxBody>, E>, ); impl<B, E> Clone for BoxRoute<B, E> {