Internally Arc Router, without breaking changes (#2483)

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2024-01-13 13:44:32 +01:00 committed by GitHub
parent d3112a40d5
commit 45116730c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 277 additions and 125 deletions

View file

@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
- **fixed:** Fix performance regression present since axum 0.7.0 ([#2483])
- **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484])
- **added:** Re-export `axum_core::body::BodyDataStream` from axum
[#2201]: https://github.com/tokio-rs/axum/pull/2201
[#2483]: https://github.com/tokio-rs/axum/pull/2483
[#2201]: https://github.com/tokio-rs/axum/pull/2201
[#2484]: https://github.com/tokio-rs/axum/pull/2484

3
axum/clippy.toml Normal file
View file

@ -0,0 +1,3 @@
disallowed-types = [
{ path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" },
]

View file

@ -1,6 +1,7 @@
use std::{convert::Infallible, fmt, sync::Mutex};
use std::{convert::Infallible, fmt};
use crate::extract::Request;
use crate::util::AxumMutex;
use tower::Service;
use crate::{
@ -9,7 +10,7 @@ use crate::{
Router,
};
pub(crate) struct BoxedIntoRoute<S, E>(Mutex<Box<dyn ErasedIntoRoute<S, E>>>);
pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
impl<S> BoxedIntoRoute<S, Infallible>
where
@ -20,7 +21,7 @@ where
H: Handler<T, S>,
T: 'static,
{
Self(Mutex::new(Box::new(MakeErasedHandler {
Self(AxumMutex::new(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
})))
@ -35,7 +36,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
E2: 'static,
{
BoxedIntoRoute(Mutex::new(Box::new(Map {
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
inner: self.0.into_inner().unwrap(),
layer: Box::new(f),
})))
@ -48,7 +49,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
impl<S, E> Clone for BoxedIntoRoute<S, E> {
fn clone(&self) -> Self {
Self(Mutex::new(self.0.lock().unwrap().clone_box()))
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
}
}
@ -118,7 +119,7 @@ where
(self.into_route)(self.router, state)
}
fn call_with_state(mut self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
fn call_with_state(self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
self.router.call_with_state(request, state)
}
}

View file

@ -1022,7 +1022,7 @@ where
self
}
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
macro_rules! call {
(
$req:expr,
@ -1034,12 +1034,12 @@ where
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.oneshot_inner($req))
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
MethodEndpoint::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
return RouteFuture::from_future(route.oneshot_inner($req))
let route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
}
@ -1073,7 +1073,7 @@ where
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
let future = fallback.call_with_state(req, state);
let future = fallback.clone().call_with_state(req, state);
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
@ -1219,7 +1219,7 @@ where
{
type Future = InfallibleRouteFuture;
fn call(mut self, req: Request, state: S) -> Self::Future {
fn call(self, req: Request, state: S) -> Self::Future {
InfallibleRouteFuture::new(self.call_with_state(req, state))
}
}

View file

@ -17,6 +17,7 @@ use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32);
/// The router type for composing handlers and services.
#[must_use]
pub struct Router<S = ()> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
inner: Arc<RouterInner<S>>,
}
impl<S> Clone for Router<S> {
fn clone(&self) -> Self {
Self {
path_router: self.path_router.clone(),
fallback_router: self.fallback_router.clone(),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.clone(),
inner: Arc::clone(&self.inner),
}
}
}
struct RouterInner<S> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
}
impl<S> Default for Router<S>
where
S: Clone + Send + Sync + 'static,
@ -88,10 +90,10 @@ where
impl<S> fmt::Debug for Router<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.path_router)
.field("fallback_router", &self.fallback_router)
.field("default_fallback", &self.default_fallback)
.field("catch_all_fallback", &self.catch_all_fallback)
.field("path_router", &self.inner.path_router)
.field("fallback_router", &self.inner.fallback_router)
.field("default_fallback", &self.inner.default_fallback)
.field("catch_all_fallback", &self.inner.catch_all_fallback)
.finish()
}
}
@ -111,22 +113,57 @@ where
/// all requests.
pub fn new() -> Self {
Self {
inner: Arc::new(RouterInner {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}),
}
}
fn map_inner<F, S2>(self, f: F) -> Router<S2>
where
F: FnOnce(RouterInner<S>) -> RouterInner<S2>,
{
Router {
inner: Arc::new(f(self.into_inner())),
}
}
fn tap_inner_mut<F>(self, f: F) -> Self
where
F: FnOnce(&mut RouterInner<S>),
{
let mut inner = self.into_inner();
f(&mut inner);
Router {
inner: Arc::new(inner),
}
}
fn into_inner(self) -> RouterInner<S> {
match Arc::try_unwrap(self.inner) {
Ok(inner) => inner,
Err(arc) => RouterInner {
path_router: arc.path_router.clone(),
fallback_router: arc.fallback_router.clone(),
default_fallback: arc.default_fallback,
catch_all_fallback: arc.catch_all_fallback.clone(),
},
}
}
#[doc = include_str!("../docs/routing/route.md")]
#[track_caller]
pub fn route(mut self, path: &str, method_router: MethodRouter<S>) -> Self {
panic_on_err!(self.path_router.route(path, method_router));
self
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
self.tap_inner_mut(|this| {
panic_on_err!(this.path_router.route(path, method_router));
})
}
#[doc = include_str!("../docs/routing/route_service.md")]
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
@ -142,14 +179,15 @@ where
Err(service) => service,
};
panic_on_err!(self.path_router.route_service(path, service));
self
self.tap_inner_mut(|this| {
panic_on_err!(this.path_router.route_service(path, service));
})
}
#[doc = include_str!("../docs/routing/nest.md")]
#[track_caller]
pub fn nest(mut self, path: &str, router: Router<S>) -> Self {
let Router {
pub fn nest(self, path: &str, router: Router<S>) -> Self {
let RouterInner {
path_router,
fallback_router,
default_fallback,
@ -157,63 +195,66 @@ where
// requests with an empty path. If we were to inherit the catch-all fallback
// it would end up matching `/{path}/*` which doesn't match empty paths.
catch_all_fallback: _,
} = router;
} = router.into_inner();
panic_on_err!(self.path_router.nest(path, path_router));
self.tap_inner_mut(|this| {
panic_on_err!(this.path_router.nest(path, path_router));
if !default_fallback {
panic_on_err!(self.fallback_router.nest(path, fallback_router));
panic_on_err!(this.fallback_router.nest(path, fallback_router));
}
self
})
}
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
#[track_caller]
pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
pub fn nest_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
panic_on_err!(self.path_router.nest_service(path, service));
self
self.tap_inner_mut(|this| {
panic_on_err!(this.path_router.nest_service(path, service));
})
}
#[doc = include_str!("../docs/routing/merge.md")]
#[track_caller]
pub fn merge<R>(mut self, other: R) -> Self
pub fn merge<R>(self, other: R) -> Self
where
R: Into<Router<S>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
let Router {
let other: Router<S> = other.into();
let RouterInner {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into();
} = other.into_inner();
panic_on_err!(self.path_router.merge(path_router));
self.map_inner(|mut this| {
panic_on_err!(this.path_router.merge(path_router));
match (self.default_fallback, default_fallback) {
match (this.default_fallback, default_fallback) {
// both have the default fallback
// use the one from other
(true, true) => {
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
// self has default fallback, other has a custom fallback
// this has default fallback, other has a custom fallback
(true, false) => {
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
self.default_fallback = false;
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.default_fallback = false;
}
// self has a custom fallback, other has a default
// this has a custom fallback, other has a default
(false, true) => {
let fallback_router = std::mem::take(&mut self.fallback_router);
let fallback_router = std::mem::take(&mut this.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
self.fallback_router = other_fallback;
this.fallback_router = other_fallback;
}
// both have a custom fallback, not allowed
(false, false) => {
@ -221,12 +262,13 @@ where
}
};
self.catch_all_fallback = self
this.catch_all_fallback = this
.catch_all_fallback
.merge(catch_all_fallback)
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
self
this
})
}
#[doc = include_str!("../docs/routing/layer.md")]
@ -238,12 +280,12 @@ where
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
{
Router {
path_router: self.path_router.layer(layer.clone()),
fallback_router: self.fallback_router.layer(layer.clone()),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
}
self.map_inner(|this| RouterInner {
path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
})
}
#[doc = include_str!("../docs/routing/route_layer.md")]
@ -256,68 +298,76 @@ where
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
{
Router {
path_router: self.path_router.route_layer(layer),
fallback_router: self.fallback_router,
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback,
}
self.map_inner(|this| RouterInner {
path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router,
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback,
})
}
#[track_caller]
#[doc = include_str!("../docs/routing/fallback.md")]
pub fn fallback<H, T>(mut self, handler: H) -> Self
pub fn fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
self.catch_all_fallback =
self.tap_inner_mut(|this| {
this.catch_all_fallback =
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
})
.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
}
/// Add a fallback [`Service`] to the router.
///
/// See [`Router::fallback`] for more details.
pub fn fallback_service<T>(mut self, service: T) -> Self
pub fn fallback_service<T>(self, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let route = Route::new(service);
self.catch_all_fallback = Fallback::Service(route.clone());
self.fallback_endpoint(Endpoint::Route(route))
self.tap_inner_mut(|this| {
this.catch_all_fallback = Fallback::Service(route.clone());
})
.fallback_endpoint(Endpoint::Route(route))
}
fn fallback_endpoint(mut self, endpoint: Endpoint<S>) -> Self {
self.fallback_router.set_fallback(endpoint);
self.default_fallback = false;
self
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
self.tap_inner_mut(|this| {
this.fallback_router.set_fallback(endpoint);
this.default_fallback = false;
})
}
#[doc = include_str!("../docs/routing/with_state.md")]
pub fn with_state<S2>(self, state: S) -> Router<S2> {
Router {
path_router: self.path_router.with_state(state.clone()),
fallback_router: self.fallback_router.with_state(state.clone()),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.with_state(state),
}
self.map_inner(|this| RouterInner {
path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.with_state(state),
})
}
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<Infallible> {
let (req, state) = match self.path_router.call_with_state(req, state) {
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<Infallible> {
let (req, state) = match self.inner.path_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};
let (req, state) = match self.fallback_router.call_with_state(req, state) {
let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};
self.catch_all_fallback.call_with_state(req, state)
self.inner
.catch_all_fallback
.clone()
.call_with_state(req, state)
}
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type

View file

@ -316,7 +316,7 @@ where
}
pub(super) fn call_with_state(
&mut self,
&self,
mut req: Request,
state: S,
) -> Result<RouteFuture<Infallible>, (Request, S)> {
@ -349,7 +349,7 @@ where
let endpoint = self
.routes
.get_mut(&id)
.get(&id)
.expect("no route for id. This is a bug in axum. Please file an issue");
match endpoint {

View file

@ -1,6 +1,7 @@
use crate::{
body::{Body, HttpBody},
response::Response,
util::AxumMutex,
};
use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
@ -14,7 +15,6 @@ use std::{
fmt,
future::Future,
pin::Pin,
sync::Mutex,
task::{Context, Poll},
};
use tower::{
@ -28,7 +28,7 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
pub struct Route<E = Infallible>(Mutex<BoxCloneService<Request, Response, E>>);
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
impl<E> Route<E> {
pub(crate) fn new<T>(svc: T) -> Self
@ -37,7 +37,7 @@ impl<E> Route<E> {
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
Self(Mutex::new(BoxCloneService::new(
Self(AxumMutex::new(BoxCloneService::new(
svc.map_response(IntoResponse::into_response),
)))
}
@ -70,8 +70,9 @@ impl<E> Route<E> {
}
impl<E> Clone for Route<E> {
#[track_caller]
fn clone(&self) -> Self {
Self(Mutex::new(self.0.lock().unwrap().clone()))
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
}
}

View file

@ -12,6 +12,7 @@ use crate::{
tracing_helpers::{capture_tracing, TracingEvent},
*,
},
util::mutex_num_locked,
BoxError, Extension, Json, Router, ServiceExt,
};
use axum_core::extract::Request;
@ -951,7 +952,7 @@ async fn state_isnt_cloned_too_much() {
client.get("/").await;
assert_eq!(COUNT.load(Ordering::SeqCst), 5);
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
}
#[crate::test]
@ -1066,3 +1067,35 @@ async fn impl_handler_for_into_response() {
assert_eq!(res.status(), StatusCode::CREATED);
assert_eq!(res.text().await, "thing created");
}
#[crate::test]
async fn locks_mutex_very_little() {
let (num, app) = mutex_num_locked(|| async {
Router::new()
.route("/a", get(|| async {}))
.route("/b", get(|| async {}))
.route("/c", get(|| async {}))
.with_state::<()>(())
.into_service::<Body>()
})
.await;
// once for `Router::new` for setting the default fallback and 3 times, once per route
assert_eq!(num, 4);
for path in ["/a", "/b", "/c"] {
// calling the router should only lock the mutex once
let (num, _res) = mutex_num_locked(|| async {
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
// connection and `mutex_num_locked` uses a task local to keep track of the number of
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
//
// So instead `call` the service directly without spawning new tasks.
app.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await
.unwrap()
})
.await;
assert_eq!(num, 1);
}
}

View file

@ -1,8 +1,5 @@
use std::{
future::Future,
io,
sync::{Arc, Mutex},
};
use crate::util::AxumMutex;
use std::{future::Future, io, sync::Arc};
use serde::{de::DeserializeOwned, Deserialize};
use tracing_subscriber::prelude::*;
@ -50,12 +47,12 @@ where
}
struct TestMakeWriter {
write: Arc<Mutex<Option<Vec<u8>>>>,
write: Arc<AxumMutex<Option<Vec<u8>>>>,
}
impl TestMakeWriter {
fn new() -> (Self, Handle) {
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
(
Self {
@ -97,7 +94,7 @@ impl<'a> io::Write for Writer<'a> {
}
struct Handle {
write: Arc<Mutex<Option<Vec<u8>>>>,
write: Arc<AxumMutex<Option<Vec<u8>>>>,
}
impl Handle {

View file

@ -1,6 +1,8 @@
use pin_project_lite::pin_project;
use std::{ops::Deref, sync::Arc};
pub(crate) use self::mutex::*;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedStr(Arc<str>);
@ -59,3 +61,65 @@ fn test_try_downcast() {
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
}
// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
// times it's been locked on the current task. That way we can write a test to ensure we don't
// accidentally introduce more locking.
//
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
mod mutex {
#[allow(clippy::disallowed_types)]
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
}
#[cfg(test)]
mod mutex {
#![allow(clippy::disallowed_types)]
use std::sync::{
atomic::{AtomicUsize, Ordering},
LockResult, Mutex, MutexGuard,
};
tokio::task_local! {
pub(crate) static NUM_LOCKED: AtomicUsize;
}
pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
where
F: FnOnce() -> Fut,
Fut: std::future::IntoFuture,
{
NUM_LOCKED
.scope(AtomicUsize::new(0), async move {
let output = f().await;
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
(num, output)
})
.await
}
pub(crate) struct AxumMutex<T>(Mutex<T>);
impl<T> AxumMutex<T> {
pub(crate) fn new(value: T) -> Self {
Self(Mutex::new(value))
}
pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
self.0.get_mut()
}
pub(crate) fn into_inner(self) -> LockResult<T> {
self.0.into_inner()
}
pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
_ = NUM_LOCKED.try_with(|num| {
num.fetch_add(1, Ordering::SeqCst);
});
self.0.lock()
}
}
}