Remove mutexes around boxed services (#2947)

This commit is contained in:
Jonas Platte 2024-09-29 19:04:31 +00:00 committed by GitHub
parent 3eb8854839
commit fb4b1899eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 17 additions and 119 deletions

View file

@ -1,3 +0,0 @@
disallowed-types = [
{ path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" },
]

View file

@ -1,7 +1,6 @@
use std::{convert::Infallible, fmt};
use crate::extract::Request;
use crate::util::AxumMutex;
use tower::Service;
use crate::{
@ -10,7 +9,7 @@ use crate::{
Router,
};
pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
pub(crate) struct BoxedIntoRoute<S, E>(Box<dyn ErasedIntoRoute<S, E>>);
impl<S> BoxedIntoRoute<S, Infallible>
where
@ -21,10 +20,10 @@ where
H: Handler<T, S>,
T: 'static,
{
Self(AxumMutex::new(Box::new(MakeErasedHandler {
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
})))
}))
}
}
@ -36,20 +35,20 @@ impl<S, E> BoxedIntoRoute<S, E> {
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
E2: 'static,
{
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
inner: self.0.into_inner().unwrap(),
BoxedIntoRoute(Box::new(Map {
inner: self.0,
layer: Box::new(f),
})))
}))
}
pub(crate) fn into_route(self, state: S) -> Route<E> {
self.0.into_inner().unwrap().into_route(state)
self.0.into_route(state)
}
}
impl<S, E> Clone for BoxedIntoRoute<S, E> {
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
Self(self.0.clone_box())
}
}

View file

@ -2,7 +2,6 @@ use crate::{
body::{Body, HttpBody},
box_clone_service::BoxCloneService,
response::Response,
util::AxumMutex,
};
use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
@ -29,7 +28,7 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
pub struct Route<E = Infallible>(BoxCloneService<Request, Response, E>);
impl<E> Route<E> {
pub(crate) fn new<T>(svc: T) -> Self
@ -38,16 +37,16 @@ impl<E> Route<E> {
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
Self(AxumMutex::new(BoxCloneService::new(
Self(BoxCloneService::new(
svc.map_response(IntoResponse::into_response),
)))
))
}
pub(crate) fn oneshot_inner(
&mut self,
req: Request,
) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
self.0.get_mut().unwrap().clone().oneshot(req)
self.0.clone().oneshot(req)
}
pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
@ -73,7 +72,7 @@ impl<E> Route<E> {
impl<E> Clone for Route<E> {
#[track_caller]
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
Self(self.0.clone())
}
}

View file

@ -12,7 +12,6 @@ use crate::{
tracing_helpers::{capture_tracing, TracingEvent},
*,
},
util::mutex_num_locked,
BoxError, Extension, Json, Router, ServiceExt,
};
use axum_core::extract::Request;
@ -1068,35 +1067,3 @@ async fn impl_handler_for_into_response() {
assert_eq!(res.status(), StatusCode::CREATED);
assert_eq!(res.text().await, "thing created");
}
#[crate::test]
async fn locks_mutex_very_little() {
let (num, app) = mutex_num_locked(|| async {
Router::new()
.route("/a", get(|| async {}))
.route("/b", get(|| async {}))
.route("/c", get(|| async {}))
.with_state::<()>(())
.into_service::<Body>()
})
.await;
// once for `Router::new` for setting the default fallback and 3 times, once per route
assert_eq!(num, 4);
for path in ["/a", "/b", "/c"] {
// calling the router should only lock the mutex once
let (num, _res) = mutex_num_locked(|| async {
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
// connection and `mutex_num_locked` uses a task local to keep track of the number of
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
//
// So instead `call` the service directly without spawning new tasks.
app.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await
.unwrap()
})
.await;
assert_eq!(num, 1);
}
}

View file

@ -1,10 +1,9 @@
use crate::util::AxumMutex;
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
sync::{Arc, Mutex},
};
use serde::{de::DeserializeOwned, Deserialize};
@ -87,12 +86,12 @@ where
}
struct TestMakeWriter {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}
impl TestMakeWriter {
fn new() -> (Self, Handle) {
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));
(
Self {
@ -134,7 +133,7 @@ impl<'a> io::Write for Writer<'a> {
}
struct Handle {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}
impl Handle {

View file

@ -1,8 +1,6 @@
use pin_project_lite::pin_project;
use std::{ops::Deref, sync::Arc};
pub(crate) use self::mutex::*;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedStr(Arc<str>);
@ -57,64 +55,3 @@ fn test_try_downcast() {
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
}
// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
// times it's been locked on the current task. That way we can write a test to ensure we don't
// accidentally introduce more locking.
//
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
mod mutex {
#[allow(clippy::disallowed_types)]
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
}
#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod mutex {
use std::sync::{
atomic::{AtomicUsize, Ordering},
LockResult, Mutex, MutexGuard,
};
tokio::task_local! {
pub(crate) static NUM_LOCKED: AtomicUsize;
}
pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
where
F: FnOnce() -> Fut,
Fut: std::future::IntoFuture,
{
NUM_LOCKED
.scope(AtomicUsize::new(0), async move {
let output = f().await;
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
(num, output)
})
.await
}
pub(crate) struct AxumMutex<T>(Mutex<T>);
impl<T> AxumMutex<T> {
pub(crate) fn new(value: T) -> Self {
Self(Mutex::new(value))
}
pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
self.0.get_mut()
}
pub(crate) fn into_inner(self) -> LockResult<T> {
self.0.into_inner()
}
pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
_ = NUM_LOCKED.try_with(|num| {
num.fetch_add(1, Ordering::SeqCst);
});
self.0.lock()
}
}
}