diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index be5f4d01..ecaa01fe 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **fixed:** Improve `debug_handler` on tuple response types ([#2201]) +- **fixed:** Fix performance regression present since axum 0.7.0 ([#2483]) - **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484]) - **added:** Re-export `axum_core::body::BodyDataStream` from axum +[#2201]: https://github.com/tokio-rs/axum/pull/2201 +[#2483]: https://github.com/tokio-rs/axum/pull/2483 [#2201]: https://github.com/tokio-rs/axum/pull/2201 [#2484]: https://github.com/tokio-rs/axum/pull/2484 diff --git a/axum/clippy.toml b/axum/clippy.toml new file mode 100644 index 00000000..291e8cd5 --- /dev/null +++ b/axum/clippy.toml @@ -0,0 +1,3 @@ +disallowed-types = [ + { path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" }, +] diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index 69fda3fd..c734c2fb 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -1,6 +1,7 @@ -use std::{convert::Infallible, fmt, sync::Mutex}; +use std::{convert::Infallible, fmt}; use crate::extract::Request; +use crate::util::AxumMutex; use tower::Service; use crate::{ @@ -9,7 +10,7 @@ use crate::{ Router, }; -pub(crate) struct BoxedIntoRoute(Mutex>>); +pub(crate) struct BoxedIntoRoute(AxumMutex>>); impl BoxedIntoRoute where @@ -20,7 +21,7 @@ where H: Handler, T: 'static, { - Self(Mutex::new(Box::new(MakeErasedHandler { + Self(AxumMutex::new(Box::new(MakeErasedHandler { handler, into_route: |handler, state| Route::new(Handler::with_state(handler, state)), }))) @@ -35,7 +36,7 @@ impl BoxedIntoRoute { F: FnOnce(Route) -> Route + Clone + Send + 'static, E2: 'static, { - BoxedIntoRoute(Mutex::new(Box::new(Map { + BoxedIntoRoute(AxumMutex::new(Box::new(Map { inner: self.0.into_inner().unwrap(), layer: Box::new(f), }))) @@ -48,7 +49,7 @@ impl BoxedIntoRoute { impl Clone for BoxedIntoRoute { fn clone(&self) -> Self { - Self(Mutex::new(self.0.lock().unwrap().clone_box())) + Self(AxumMutex::new(self.0.lock().unwrap().clone_box())) } } @@ -118,7 +119,7 @@ where (self.into_route)(self.router, state) } - fn call_with_state(mut self: Box, request: Request, state: S) -> RouteFuture { + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { self.router.call_with_state(request, state) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 962a4401..bb6a1bc9 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1022,7 +1022,7 @@ where self } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { macro_rules! call { ( $req:expr, @@ -1034,12 +1034,12 @@ where match $svc { MethodEndpoint::None => {} MethodEndpoint::Route(route) => { - return RouteFuture::from_future(route.oneshot_inner($req)) + return RouteFuture::from_future(route.clone().oneshot_inner($req)) .strip_body($method == Method::HEAD); } MethodEndpoint::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - return RouteFuture::from_future(route.oneshot_inner($req)) + let route = handler.clone().into_route(state); + return RouteFuture::from_future(route.clone().oneshot_inner($req)) .strip_body($method == Method::HEAD); } } @@ -1073,7 +1073,7 @@ where call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - let future = fallback.call_with_state(req, state); + let future = fallback.clone().call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), @@ -1219,7 +1219,7 @@ where { type Future = InfallibleRouteFuture; - fn call(mut self, req: Request, state: S) -> Self::Future { + fn call(self, req: Request, state: S) -> Self::Future { InfallibleRouteFuture::new(self.call_with_state(req, state)) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 49237e4e..fad920a4 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -17,6 +17,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; @@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32); /// The router type for composing handlers and services. #[must_use] pub struct Router { - path_router: PathRouter, - fallback_router: PathRouter, - default_fallback: bool, - catch_all_fallback: Fallback, + inner: Arc>, } impl Clone for Router { fn clone(&self) -> Self { Self { - path_router: self.path_router.clone(), - fallback_router: self.fallback_router.clone(), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.clone(), + inner: Arc::clone(&self.inner), } } } +struct RouterInner { + path_router: PathRouter, + fallback_router: PathRouter, + default_fallback: bool, + catch_all_fallback: Fallback, +} + impl Default for Router where S: Clone + Send + Sync + 'static, @@ -88,10 +90,10 @@ where impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") - .field("path_router", &self.path_router) - .field("fallback_router", &self.fallback_router) - .field("default_fallback", &self.default_fallback) - .field("catch_all_fallback", &self.catch_all_fallback) + .field("path_router", &self.inner.path_router) + .field("fallback_router", &self.inner.fallback_router) + .field("default_fallback", &self.inner.default_fallback) + .field("catch_all_fallback", &self.inner.catch_all_fallback) .finish() } } @@ -111,22 +113,57 @@ where /// all requests. pub fn new() -> Self { Self { - path_router: Default::default(), - fallback_router: PathRouter::new_fallback(), - default_fallback: true, - catch_all_fallback: Fallback::Default(Route::new(NotFound)), + inner: Arc::new(RouterInner { + path_router: Default::default(), + fallback_router: PathRouter::new_fallback(), + default_fallback: true, + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + }), + } + } + + fn map_inner(self, f: F) -> Router + where + F: FnOnce(RouterInner) -> RouterInner, + { + Router { + inner: Arc::new(f(self.into_inner())), + } + } + + fn tap_inner_mut(self, f: F) -> Self + where + F: FnOnce(&mut RouterInner), + { + let mut inner = self.into_inner(); + f(&mut inner); + Router { + inner: Arc::new(inner), + } + } + + fn into_inner(self) -> RouterInner { + match Arc::try_unwrap(self.inner) { + Ok(inner) => inner, + Err(arc) => RouterInner { + path_router: arc.path_router.clone(), + fallback_router: arc.fallback_router.clone(), + default_fallback: arc.default_fallback, + catch_all_fallback: arc.catch_all_fallback.clone(), + }, } } #[doc = include_str!("../docs/routing/route.md")] #[track_caller] - pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { - panic_on_err!(self.path_router.route(path, method_router)); - self + pub fn route(self, path: &str, method_router: MethodRouter) -> Self { + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.route(path, method_router)); + }) } #[doc = include_str!("../docs/routing/route_service.md")] - pub fn route_service(mut self, path: &str, service: T) -> Self + pub fn route_service(self, path: &str, service: T) -> Self where T: Service + Clone + Send + 'static, T::Response: IntoResponse, @@ -142,14 +179,15 @@ where Err(service) => service, }; - panic_on_err!(self.path_router.route_service(path, service)); - self + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.route_service(path, service)); + }) } #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] - pub fn nest(mut self, path: &str, router: Router) -> Self { - let Router { + pub fn nest(self, path: &str, router: Router) -> Self { + let RouterInner { path_router, fallback_router, default_fallback, @@ -157,76 +195,80 @@ where // requests with an empty path. If we were to inherit the catch-all fallback // it would end up matching `/{path}/*` which doesn't match empty paths. catch_all_fallback: _, - } = router; + } = router.into_inner(); - panic_on_err!(self.path_router.nest(path, path_router)); + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.nest(path, path_router)); - if !default_fallback { - panic_on_err!(self.fallback_router.nest(path, fallback_router)); - } - - self + if !default_fallback { + panic_on_err!(this.fallback_router.nest(path, fallback_router)); + } + }) } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. #[track_caller] - pub fn nest_service(mut self, path: &str, service: T) -> Self + pub fn nest_service(self, path: &str, service: T) -> Self where T: Service + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - panic_on_err!(self.path_router.nest_service(path, service)); - self + self.tap_inner_mut(|this| { + panic_on_err!(this.path_router.nest_service(path, service)); + }) } #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: R) -> Self + pub fn merge(self, other: R) -> Self where R: Into>, { const PANIC_MSG: &str = "Failed to merge fallbacks. This is a bug in axum. Please file an issue"; - let Router { + let other: Router = other.into(); + let RouterInner { path_router, fallback_router: mut other_fallback, default_fallback, catch_all_fallback, - } = other.into(); + } = other.into_inner(); - panic_on_err!(self.path_router.merge(path_router)); + self.map_inner(|mut this| { + panic_on_err!(this.path_router.merge(path_router)); - match (self.default_fallback, default_fallback) { - // both have the default fallback - // use the one from other - (true, true) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - } - // self has default fallback, other has a custom fallback - (true, false) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - self.default_fallback = false; - } - // self has a custom fallback, other has a default - (false, true) => { - let fallback_router = std::mem::take(&mut self.fallback_router); - other_fallback.merge(fallback_router).expect(PANIC_MSG); - self.fallback_router = other_fallback; - } - // both have a custom fallback, not allowed - (false, false) => { - panic!("Cannot merge two `Router`s that both have a fallback") - } - }; + match (this.default_fallback, default_fallback) { + // both have the default fallback + // use the one from other + (true, true) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + } + // this has default fallback, other has a custom fallback + (true, false) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + this.default_fallback = false; + } + // this has a custom fallback, other has a default + (false, true) => { + let fallback_router = std::mem::take(&mut this.fallback_router); + other_fallback.merge(fallback_router).expect(PANIC_MSG); + this.fallback_router = other_fallback; + } + // both have a custom fallback, not allowed + (false, false) => { + panic!("Cannot merge two `Router`s that both have a fallback") + } + }; - self.catch_all_fallback = self - .catch_all_fallback - .merge(catch_all_fallback) - .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + this.catch_all_fallback = this + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); - self + this + }) } #[doc = include_str!("../docs/routing/layer.md")] @@ -238,12 +280,12 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.layer(layer.clone()), + fallback_router: this.fallback_router.layer(layer.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)), + }) } #[doc = include_str!("../docs/routing/route_layer.md")] @@ -256,68 +298,76 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.route_layer(layer), - fallback_router: self.fallback_router, - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback, - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.route_layer(layer), + fallback_router: this.fallback_router, + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback, + }) } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, handler: H) -> Self + pub fn fallback(self, handler: H) -> Self where H: Handler, T: 'static, { - self.catch_all_fallback = - Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); - self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) + self.tap_inner_mut(|this| { + this.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + }) + .fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(mut self, service: T) -> Self + pub fn fallback_service(self, service: T) -> Self where T: Service + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let route = Route::new(service); - self.catch_all_fallback = Fallback::Service(route.clone()); - self.fallback_endpoint(Endpoint::Route(route)) + self.tap_inner_mut(|this| { + this.catch_all_fallback = Fallback::Service(route.clone()); + }) + .fallback_endpoint(Endpoint::Route(route)) } - fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { - self.fallback_router.set_fallback(endpoint); - self.default_fallback = false; - self + fn fallback_endpoint(self, endpoint: Endpoint) -> Self { + self.tap_inner_mut(|this| { + this.fallback_router.set_fallback(endpoint); + this.default_fallback = false; + }) } #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - Router { - path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.with_state(state), - } + self.map_inner(|this| RouterInner { + path_router: this.path_router.with_state(state.clone()), + fallback_router: this.fallback_router.with_state(state.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.with_state(state), + }) } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { - let (req, state) = match self.path_router.call_with_state(req, state) { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { + let (req, state) = match self.inner.path_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - let (req, state) = match self.fallback_router.call_with_state(req, state) { + let (req, state) = match self.inner.fallback_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - self.catch_all_fallback.call_with_state(req, state) + self.inner + .catch_all_fallback + .clone() + .call_with_state(req, state) } /// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index b4ef4cb4..e9353dc7 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -316,7 +316,7 @@ where } pub(super) fn call_with_state( - &mut self, + &self, mut req: Request, state: S, ) -> Result, (Request, S)> { @@ -349,7 +349,7 @@ where let endpoint = self .routes - .get_mut(&id) + .get(&id) .expect("no route for id. This is a bug in axum. Please file an issue"); match endpoint { diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 4f3cffd7..2bde8c8c 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -1,6 +1,7 @@ use crate::{ body::{Body, HttpBody}, response::Response, + util::AxumMutex, }; use axum_core::{extract::Request, response::IntoResponse}; use bytes::Bytes; @@ -14,7 +15,6 @@ use std::{ fmt, future::Future, pin::Pin, - sync::Mutex, task::{Context, Poll}, }; use tower::{ @@ -28,7 +28,7 @@ use tower_service::Service; /// /// You normally shouldn't need to care about this type. It's used in /// [`Router::layer`](super::Router::layer). -pub struct Route(Mutex>); +pub struct Route(AxumMutex>); impl Route { pub(crate) fn new(svc: T) -> Self @@ -37,7 +37,7 @@ impl Route { T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - Self(Mutex::new(BoxCloneService::new( + Self(AxumMutex::new(BoxCloneService::new( svc.map_response(IntoResponse::into_response), ))) } @@ -70,8 +70,9 @@ impl Route { } impl Clone for Route { + #[track_caller] fn clone(&self) -> Self { - Self(Mutex::new(self.0.lock().unwrap().clone())) + Self(AxumMutex::new(self.0.lock().unwrap().clone())) } } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 0621156b..3e66eb7a 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -12,6 +12,7 @@ use crate::{ tracing_helpers::{capture_tracing, TracingEvent}, *, }, + util::mutex_num_locked, BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; @@ -951,7 +952,7 @@ async fn state_isnt_cloned_too_much() { client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 5); + assert_eq!(COUNT.load(Ordering::SeqCst), 4); } #[crate::test] @@ -1066,3 +1067,35 @@ async fn impl_handler_for_into_response() { assert_eq!(res.status(), StatusCode::CREATED); assert_eq!(res.text().await, "thing created"); } + +#[crate::test] +async fn locks_mutex_very_little() { + let (num, app) = mutex_num_locked(|| async { + Router::new() + .route("/a", get(|| async {})) + .route("/b", get(|| async {})) + .route("/c", get(|| async {})) + .with_state::<()>(()) + .into_service::() + }) + .await; + // once for `Router::new` for setting the default fallback and 3 times, once per route + assert_eq!(num, 4); + + for path in ["/a", "/b", "/c"] { + // calling the router should only lock the mutex once + let (num, _res) = mutex_num_locked(|| async { + // We cannot use `TestClient` because it uses `serve` which spawns a new task per + // connection and `mutex_num_locked` uses a task local to keep track of the number of + // locks. So spawning a new task would unset the task local set by `mutex_num_locked` + // + // So instead `call` the service directly without spawning new tasks. + app.clone() + .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) + .await + .unwrap() + }) + .await; + assert_eq!(num, 1); + } +} diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index 3d5cf181..2240717e 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -1,8 +1,5 @@ -use std::{ - future::Future, - io, - sync::{Arc, Mutex}, -}; +use crate::util::AxumMutex; +use std::{future::Future, io, sync::Arc}; use serde::{de::DeserializeOwned, Deserialize}; use tracing_subscriber::prelude::*; @@ -50,12 +47,12 @@ where } struct TestMakeWriter { - write: Arc>>>, + write: Arc>>>, } impl TestMakeWriter { fn new() -> (Self, Handle) { - let write = Arc::new(Mutex::new(Some(Vec::::new()))); + let write = Arc::new(AxumMutex::new(Some(Vec::::new()))); ( Self { @@ -97,7 +94,7 @@ impl<'a> io::Write for Writer<'a> { } struct Handle { - write: Arc>>>, + write: Arc>>>, } impl Handle { diff --git a/axum/src/util.rs b/axum/src/util.rs index f7fc6ae1..e3195f3e 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -1,6 +1,8 @@ use pin_project_lite::pin_project; use std::{ops::Deref, sync::Arc}; +pub(crate) use self::mutex::*; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct PercentDecodedStr(Arc); @@ -59,3 +61,65 @@ fn test_try_downcast() { assert_eq!(try_downcast::(5_u32), Err(5_u32)); assert_eq!(try_downcast::(5_i32), Ok(5_i32)); } + +// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of +// times it's been locked on the current task. That way we can write a test to ensure we don't +// accidentally introduce more locking. +// +// When not in test mode, it is just a type alias for `std::sync::Mutex`. +#[cfg(not(test))] +mod mutex { + #[allow(clippy::disallowed_types)] + pub(crate) type AxumMutex = std::sync::Mutex; +} + +#[cfg(test)] +mod mutex { + #![allow(clippy::disallowed_types)] + + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + LockResult, Mutex, MutexGuard, + }; + + tokio::task_local! { + pub(crate) static NUM_LOCKED: AtomicUsize; + } + + pub(crate) async fn mutex_num_locked(f: F) -> (usize, Fut::Output) + where + F: FnOnce() -> Fut, + Fut: std::future::IntoFuture, + { + NUM_LOCKED + .scope(AtomicUsize::new(0), async move { + let output = f().await; + let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst)); + (num, output) + }) + .await + } + + pub(crate) struct AxumMutex(Mutex); + + impl AxumMutex { + pub(crate) fn new(value: T) -> Self { + Self(Mutex::new(value)) + } + + pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> { + self.0.get_mut() + } + + pub(crate) fn into_inner(self) -> LockResult { + self.0.into_inner() + } + + pub(crate) fn lock(&self) -> LockResult> { + _ = NUM_LOCKED.try_with(|num| { + num.fetch_add(1, Ordering::SeqCst); + }); + self.0.lock() + } + } +}