mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-24 16:17:56 +01:00
Internally Arc
Router
, without breaking changes (#2483)
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
parent
d3112a40d5
commit
45116730c6
10 changed files with 277 additions and 125 deletions
|
@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
# Unreleased
|
||||
|
||||
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
|
||||
- **fixed:** Fix performance regression present since axum 0.7.0 ([#2483])
|
||||
- **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484])
|
||||
- **added:** Re-export `axum_core::body::BodyDataStream` from axum
|
||||
|
||||
[#2201]: https://github.com/tokio-rs/axum/pull/2201
|
||||
[#2483]: https://github.com/tokio-rs/axum/pull/2483
|
||||
[#2201]: https://github.com/tokio-rs/axum/pull/2201
|
||||
[#2484]: https://github.com/tokio-rs/axum/pull/2484
|
||||
|
||||
|
|
3
axum/clippy.toml
Normal file
3
axum/clippy.toml
Normal file
|
@ -0,0 +1,3 @@
|
|||
disallowed-types = [
|
||||
{ path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" },
|
||||
]
|
|
@ -1,6 +1,7 @@
|
|||
use std::{convert::Infallible, fmt, sync::Mutex};
|
||||
use std::{convert::Infallible, fmt};
|
||||
|
||||
use crate::extract::Request;
|
||||
use crate::util::AxumMutex;
|
||||
use tower::Service;
|
||||
|
||||
use crate::{
|
||||
|
@ -9,7 +10,7 @@ use crate::{
|
|||
Router,
|
||||
};
|
||||
|
||||
pub(crate) struct BoxedIntoRoute<S, E>(Mutex<Box<dyn ErasedIntoRoute<S, E>>>);
|
||||
pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
|
||||
|
||||
impl<S> BoxedIntoRoute<S, Infallible>
|
||||
where
|
||||
|
@ -20,7 +21,7 @@ where
|
|||
H: Handler<T, S>,
|
||||
T: 'static,
|
||||
{
|
||||
Self(Mutex::new(Box::new(MakeErasedHandler {
|
||||
Self(AxumMutex::new(Box::new(MakeErasedHandler {
|
||||
handler,
|
||||
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
|
||||
})))
|
||||
|
@ -35,7 +36,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
|
|||
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
BoxedIntoRoute(Mutex::new(Box::new(Map {
|
||||
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
|
||||
inner: self.0.into_inner().unwrap(),
|
||||
layer: Box::new(f),
|
||||
})))
|
||||
|
@ -48,7 +49,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
|
|||
|
||||
impl<S, E> Clone for BoxedIntoRoute<S, E> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(Mutex::new(self.0.lock().unwrap().clone_box()))
|
||||
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,7 +119,7 @@ where
|
|||
(self.into_route)(self.router, state)
|
||||
}
|
||||
|
||||
fn call_with_state(mut self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
|
||||
fn call_with_state(self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
|
||||
self.router.call_with_state(request, state)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1022,7 +1022,7 @@ where
|
|||
self
|
||||
}
|
||||
|
||||
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
|
||||
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
|
||||
macro_rules! call {
|
||||
(
|
||||
$req:expr,
|
||||
|
@ -1034,12 +1034,12 @@ where
|
|||
match $svc {
|
||||
MethodEndpoint::None => {}
|
||||
MethodEndpoint::Route(route) => {
|
||||
return RouteFuture::from_future(route.oneshot_inner($req))
|
||||
return RouteFuture::from_future(route.clone().oneshot_inner($req))
|
||||
.strip_body($method == Method::HEAD);
|
||||
}
|
||||
MethodEndpoint::BoxedHandler(handler) => {
|
||||
let mut route = handler.clone().into_route(state);
|
||||
return RouteFuture::from_future(route.oneshot_inner($req))
|
||||
let route = handler.clone().into_route(state);
|
||||
return RouteFuture::from_future(route.clone().oneshot_inner($req))
|
||||
.strip_body($method == Method::HEAD);
|
||||
}
|
||||
}
|
||||
|
@ -1073,7 +1073,7 @@ where
|
|||
call!(req, method, DELETE, delete);
|
||||
call!(req, method, TRACE, trace);
|
||||
|
||||
let future = fallback.call_with_state(req, state);
|
||||
let future = fallback.clone().call_with_state(req, state);
|
||||
|
||||
match allow_header {
|
||||
AllowHeader::None => future.allow_header(Bytes::new()),
|
||||
|
@ -1219,7 +1219,7 @@ where
|
|||
{
|
||||
type Future = InfallibleRouteFuture;
|
||||
|
||||
fn call(mut self, req: Request, state: S) -> Self::Future {
|
||||
fn call(self, req: Request, state: S) -> Self::Future {
|
||||
InfallibleRouteFuture::new(self.call_with_state(req, state))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ use std::{
|
|||
convert::Infallible,
|
||||
fmt,
|
||||
marker::PhantomData,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower_layer::Layer;
|
||||
|
@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32);
|
|||
/// The router type for composing handlers and services.
|
||||
#[must_use]
|
||||
pub struct Router<S = ()> {
|
||||
path_router: PathRouter<S, false>,
|
||||
fallback_router: PathRouter<S, true>,
|
||||
default_fallback: bool,
|
||||
catch_all_fallback: Fallback<S>,
|
||||
inner: Arc<RouterInner<S>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for Router<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
path_router: self.path_router.clone(),
|
||||
fallback_router: self.fallback_router.clone(),
|
||||
default_fallback: self.default_fallback,
|
||||
catch_all_fallback: self.catch_all_fallback.clone(),
|
||||
inner: Arc::clone(&self.inner),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RouterInner<S> {
|
||||
path_router: PathRouter<S, false>,
|
||||
fallback_router: PathRouter<S, true>,
|
||||
default_fallback: bool,
|
||||
catch_all_fallback: Fallback<S>,
|
||||
}
|
||||
|
||||
impl<S> Default for Router<S>
|
||||
where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
|
@ -88,10 +90,10 @@ where
|
|||
impl<S> fmt::Debug for Router<S> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Router")
|
||||
.field("path_router", &self.path_router)
|
||||
.field("fallback_router", &self.fallback_router)
|
||||
.field("default_fallback", &self.default_fallback)
|
||||
.field("catch_all_fallback", &self.catch_all_fallback)
|
||||
.field("path_router", &self.inner.path_router)
|
||||
.field("fallback_router", &self.inner.fallback_router)
|
||||
.field("default_fallback", &self.inner.default_fallback)
|
||||
.field("catch_all_fallback", &self.inner.catch_all_fallback)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -111,22 +113,57 @@ where
|
|||
/// all requests.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RouterInner {
|
||||
path_router: Default::default(),
|
||||
fallback_router: PathRouter::new_fallback(),
|
||||
default_fallback: true,
|
||||
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn map_inner<F, S2>(self, f: F) -> Router<S2>
|
||||
where
|
||||
F: FnOnce(RouterInner<S>) -> RouterInner<S2>,
|
||||
{
|
||||
Router {
|
||||
inner: Arc::new(f(self.into_inner())),
|
||||
}
|
||||
}
|
||||
|
||||
fn tap_inner_mut<F>(self, f: F) -> Self
|
||||
where
|
||||
F: FnOnce(&mut RouterInner<S>),
|
||||
{
|
||||
let mut inner = self.into_inner();
|
||||
f(&mut inner);
|
||||
Router {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_inner(self) -> RouterInner<S> {
|
||||
match Arc::try_unwrap(self.inner) {
|
||||
Ok(inner) => inner,
|
||||
Err(arc) => RouterInner {
|
||||
path_router: arc.path_router.clone(),
|
||||
fallback_router: arc.fallback_router.clone(),
|
||||
default_fallback: arc.default_fallback,
|
||||
catch_all_fallback: arc.catch_all_fallback.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route.md")]
|
||||
#[track_caller]
|
||||
pub fn route(mut self, path: &str, method_router: MethodRouter<S>) -> Self {
|
||||
panic_on_err!(self.path_router.route(path, method_router));
|
||||
self
|
||||
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
|
||||
self.tap_inner_mut(|this| {
|
||||
panic_on_err!(this.path_router.route(path, method_router));
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route_service.md")]
|
||||
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
|
||||
pub fn route_service<T>(self, path: &str, service: T) -> Self
|
||||
where
|
||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
|
@ -142,14 +179,15 @@ where
|
|||
Err(service) => service,
|
||||
};
|
||||
|
||||
panic_on_err!(self.path_router.route_service(path, service));
|
||||
self
|
||||
self.tap_inner_mut(|this| {
|
||||
panic_on_err!(this.path_router.route_service(path, service));
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/nest.md")]
|
||||
#[track_caller]
|
||||
pub fn nest(mut self, path: &str, router: Router<S>) -> Self {
|
||||
let Router {
|
||||
pub fn nest(self, path: &str, router: Router<S>) -> Self {
|
||||
let RouterInner {
|
||||
path_router,
|
||||
fallback_router,
|
||||
default_fallback,
|
||||
|
@ -157,63 +195,66 @@ where
|
|||
// requests with an empty path. If we were to inherit the catch-all fallback
|
||||
// it would end up matching `/{path}/*` which doesn't match empty paths.
|
||||
catch_all_fallback: _,
|
||||
} = router;
|
||||
} = router.into_inner();
|
||||
|
||||
panic_on_err!(self.path_router.nest(path, path_router));
|
||||
self.tap_inner_mut(|this| {
|
||||
panic_on_err!(this.path_router.nest(path, path_router));
|
||||
|
||||
if !default_fallback {
|
||||
panic_on_err!(self.fallback_router.nest(path, fallback_router));
|
||||
panic_on_err!(this.fallback_router.nest(path, fallback_router));
|
||||
}
|
||||
|
||||
self
|
||||
})
|
||||
}
|
||||
|
||||
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
|
||||
#[track_caller]
|
||||
pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
|
||||
pub fn nest_service<T>(self, path: &str, service: T) -> Self
|
||||
where
|
||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
panic_on_err!(self.path_router.nest_service(path, service));
|
||||
self
|
||||
self.tap_inner_mut(|this| {
|
||||
panic_on_err!(this.path_router.nest_service(path, service));
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/merge.md")]
|
||||
#[track_caller]
|
||||
pub fn merge<R>(mut self, other: R) -> Self
|
||||
pub fn merge<R>(self, other: R) -> Self
|
||||
where
|
||||
R: Into<Router<S>>,
|
||||
{
|
||||
const PANIC_MSG: &str =
|
||||
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
|
||||
|
||||
let Router {
|
||||
let other: Router<S> = other.into();
|
||||
let RouterInner {
|
||||
path_router,
|
||||
fallback_router: mut other_fallback,
|
||||
default_fallback,
|
||||
catch_all_fallback,
|
||||
} = other.into();
|
||||
} = other.into_inner();
|
||||
|
||||
panic_on_err!(self.path_router.merge(path_router));
|
||||
self.map_inner(|mut this| {
|
||||
panic_on_err!(this.path_router.merge(path_router));
|
||||
|
||||
match (self.default_fallback, default_fallback) {
|
||||
match (this.default_fallback, default_fallback) {
|
||||
// both have the default fallback
|
||||
// use the one from other
|
||||
(true, true) => {
|
||||
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||
}
|
||||
// self has default fallback, other has a custom fallback
|
||||
// this has default fallback, other has a custom fallback
|
||||
(true, false) => {
|
||||
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||
self.default_fallback = false;
|
||||
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||
this.default_fallback = false;
|
||||
}
|
||||
// self has a custom fallback, other has a default
|
||||
// this has a custom fallback, other has a default
|
||||
(false, true) => {
|
||||
let fallback_router = std::mem::take(&mut self.fallback_router);
|
||||
let fallback_router = std::mem::take(&mut this.fallback_router);
|
||||
other_fallback.merge(fallback_router).expect(PANIC_MSG);
|
||||
self.fallback_router = other_fallback;
|
||||
this.fallback_router = other_fallback;
|
||||
}
|
||||
// both have a custom fallback, not allowed
|
||||
(false, false) => {
|
||||
|
@ -221,12 +262,13 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
self.catch_all_fallback = self
|
||||
this.catch_all_fallback = this
|
||||
.catch_all_fallback
|
||||
.merge(catch_all_fallback)
|
||||
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
|
||||
|
||||
self
|
||||
this
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/layer.md")]
|
||||
|
@ -238,12 +280,12 @@ where
|
|||
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
||||
<L::Service as Service<Request>>::Future: Send + 'static,
|
||||
{
|
||||
Router {
|
||||
path_router: self.path_router.layer(layer.clone()),
|
||||
fallback_router: self.fallback_router.layer(layer.clone()),
|
||||
default_fallback: self.default_fallback,
|
||||
catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
|
||||
}
|
||||
self.map_inner(|this| RouterInner {
|
||||
path_router: this.path_router.layer(layer.clone()),
|
||||
fallback_router: this.fallback_router.layer(layer.clone()),
|
||||
default_fallback: this.default_fallback,
|
||||
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route_layer.md")]
|
||||
|
@ -256,68 +298,76 @@ where
|
|||
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
||||
<L::Service as Service<Request>>::Future: Send + 'static,
|
||||
{
|
||||
Router {
|
||||
path_router: self.path_router.route_layer(layer),
|
||||
fallback_router: self.fallback_router,
|
||||
default_fallback: self.default_fallback,
|
||||
catch_all_fallback: self.catch_all_fallback,
|
||||
}
|
||||
self.map_inner(|this| RouterInner {
|
||||
path_router: this.path_router.route_layer(layer),
|
||||
fallback_router: this.fallback_router,
|
||||
default_fallback: this.default_fallback,
|
||||
catch_all_fallback: this.catch_all_fallback,
|
||||
})
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||
pub fn fallback<H, T>(mut self, handler: H) -> Self
|
||||
pub fn fallback<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: Handler<T, S>,
|
||||
T: 'static,
|
||||
{
|
||||
self.catch_all_fallback =
|
||||
self.tap_inner_mut(|this| {
|
||||
this.catch_all_fallback =
|
||||
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
|
||||
self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
|
||||
})
|
||||
.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
|
||||
}
|
||||
|
||||
/// Add a fallback [`Service`] to the router.
|
||||
///
|
||||
/// See [`Router::fallback`] for more details.
|
||||
pub fn fallback_service<T>(mut self, service: T) -> Self
|
||||
pub fn fallback_service<T>(self, service: T) -> Self
|
||||
where
|
||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||
T::Response: IntoResponse,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
let route = Route::new(service);
|
||||
self.catch_all_fallback = Fallback::Service(route.clone());
|
||||
self.fallback_endpoint(Endpoint::Route(route))
|
||||
self.tap_inner_mut(|this| {
|
||||
this.catch_all_fallback = Fallback::Service(route.clone());
|
||||
})
|
||||
.fallback_endpoint(Endpoint::Route(route))
|
||||
}
|
||||
|
||||
fn fallback_endpoint(mut self, endpoint: Endpoint<S>) -> Self {
|
||||
self.fallback_router.set_fallback(endpoint);
|
||||
self.default_fallback = false;
|
||||
self
|
||||
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
|
||||
self.tap_inner_mut(|this| {
|
||||
this.fallback_router.set_fallback(endpoint);
|
||||
this.default_fallback = false;
|
||||
})
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/with_state.md")]
|
||||
pub fn with_state<S2>(self, state: S) -> Router<S2> {
|
||||
Router {
|
||||
path_router: self.path_router.with_state(state.clone()),
|
||||
fallback_router: self.fallback_router.with_state(state.clone()),
|
||||
default_fallback: self.default_fallback,
|
||||
catch_all_fallback: self.catch_all_fallback.with_state(state),
|
||||
}
|
||||
self.map_inner(|this| RouterInner {
|
||||
path_router: this.path_router.with_state(state.clone()),
|
||||
fallback_router: this.fallback_router.with_state(state.clone()),
|
||||
default_fallback: this.default_fallback,
|
||||
catch_all_fallback: this.catch_all_fallback.with_state(state),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<Infallible> {
|
||||
let (req, state) = match self.path_router.call_with_state(req, state) {
|
||||
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<Infallible> {
|
||||
let (req, state) = match self.inner.path_router.call_with_state(req, state) {
|
||||
Ok(future) => return future,
|
||||
Err((req, state)) => (req, state),
|
||||
};
|
||||
|
||||
let (req, state) = match self.fallback_router.call_with_state(req, state) {
|
||||
let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
|
||||
Ok(future) => return future,
|
||||
Err((req, state)) => (req, state),
|
||||
};
|
||||
|
||||
self.catch_all_fallback.call_with_state(req, state)
|
||||
self.inner
|
||||
.catch_all_fallback
|
||||
.clone()
|
||||
.call_with_state(req, state)
|
||||
}
|
||||
|
||||
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type
|
||||
|
|
|
@ -316,7 +316,7 @@ where
|
|||
}
|
||||
|
||||
pub(super) fn call_with_state(
|
||||
&mut self,
|
||||
&self,
|
||||
mut req: Request,
|
||||
state: S,
|
||||
) -> Result<RouteFuture<Infallible>, (Request, S)> {
|
||||
|
@ -349,7 +349,7 @@ where
|
|||
|
||||
let endpoint = self
|
||||
.routes
|
||||
.get_mut(&id)
|
||||
.get(&id)
|
||||
.expect("no route for id. This is a bug in axum. Please file an issue");
|
||||
|
||||
match endpoint {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::{
|
||||
body::{Body, HttpBody},
|
||||
response::Response,
|
||||
util::AxumMutex,
|
||||
};
|
||||
use axum_core::{extract::Request, response::IntoResponse};
|
||||
use bytes::Bytes;
|
||||
|
@ -14,7 +15,6 @@ use std::{
|
|||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Mutex,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{
|
||||
|
@ -28,7 +28,7 @@ use tower_service::Service;
|
|||
///
|
||||
/// You normally shouldn't need to care about this type. It's used in
|
||||
/// [`Router::layer`](super::Router::layer).
|
||||
pub struct Route<E = Infallible>(Mutex<BoxCloneService<Request, Response, E>>);
|
||||
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
|
||||
|
||||
impl<E> Route<E> {
|
||||
pub(crate) fn new<T>(svc: T) -> Self
|
||||
|
@ -37,7 +37,7 @@ impl<E> Route<E> {
|
|||
T::Response: IntoResponse + 'static,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
Self(Mutex::new(BoxCloneService::new(
|
||||
Self(AxumMutex::new(BoxCloneService::new(
|
||||
svc.map_response(IntoResponse::into_response),
|
||||
)))
|
||||
}
|
||||
|
@ -70,8 +70,9 @@ impl<E> Route<E> {
|
|||
}
|
||||
|
||||
impl<E> Clone for Route<E> {
|
||||
#[track_caller]
|
||||
fn clone(&self) -> Self {
|
||||
Self(Mutex::new(self.0.lock().unwrap().clone()))
|
||||
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ use crate::{
|
|||
tracing_helpers::{capture_tracing, TracingEvent},
|
||||
*,
|
||||
},
|
||||
util::mutex_num_locked,
|
||||
BoxError, Extension, Json, Router, ServiceExt,
|
||||
};
|
||||
use axum_core::extract::Request;
|
||||
|
@ -951,7 +952,7 @@ async fn state_isnt_cloned_too_much() {
|
|||
|
||||
client.get("/").await;
|
||||
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 5);
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
|
@ -1066,3 +1067,35 @@ async fn impl_handler_for_into_response() {
|
|||
assert_eq!(res.status(), StatusCode::CREATED);
|
||||
assert_eq!(res.text().await, "thing created");
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn locks_mutex_very_little() {
|
||||
let (num, app) = mutex_num_locked(|| async {
|
||||
Router::new()
|
||||
.route("/a", get(|| async {}))
|
||||
.route("/b", get(|| async {}))
|
||||
.route("/c", get(|| async {}))
|
||||
.with_state::<()>(())
|
||||
.into_service::<Body>()
|
||||
})
|
||||
.await;
|
||||
// once for `Router::new` for setting the default fallback and 3 times, once per route
|
||||
assert_eq!(num, 4);
|
||||
|
||||
for path in ["/a", "/b", "/c"] {
|
||||
// calling the router should only lock the mutex once
|
||||
let (num, _res) = mutex_num_locked(|| async {
|
||||
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
|
||||
// connection and `mutex_num_locked` uses a task local to keep track of the number of
|
||||
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
|
||||
//
|
||||
// So instead `call` the service directly without spawning new tasks.
|
||||
app.clone()
|
||||
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
|
||||
.await
|
||||
.unwrap()
|
||||
})
|
||||
.await;
|
||||
assert_eq!(num, 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
io,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use crate::util::AxumMutex;
|
||||
use std::{future::Future, io, sync::Arc};
|
||||
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
@ -50,12 +47,12 @@ where
|
|||
}
|
||||
|
||||
struct TestMakeWriter {
|
||||
write: Arc<Mutex<Option<Vec<u8>>>>,
|
||||
write: Arc<AxumMutex<Option<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl TestMakeWriter {
|
||||
fn new() -> (Self, Handle) {
|
||||
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));
|
||||
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
|
||||
|
||||
(
|
||||
Self {
|
||||
|
@ -97,7 +94,7 @@ impl<'a> io::Write for Writer<'a> {
|
|||
}
|
||||
|
||||
struct Handle {
|
||||
write: Arc<Mutex<Option<Vec<u8>>>>,
|
||||
write: Arc<AxumMutex<Option<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl Handle {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use pin_project_lite::pin_project;
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
pub(crate) use self::mutex::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct PercentDecodedStr(Arc<str>);
|
||||
|
||||
|
@ -59,3 +61,65 @@ fn test_try_downcast() {
|
|||
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
|
||||
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
|
||||
}
|
||||
|
||||
// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
|
||||
// times it's been locked on the current task. That way we can write a test to ensure we don't
|
||||
// accidentally introduce more locking.
|
||||
//
|
||||
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
|
||||
#[cfg(not(test))]
|
||||
mod mutex {
|
||||
#[allow(clippy::disallowed_types)]
|
||||
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod mutex {
|
||||
#![allow(clippy::disallowed_types)]
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
LockResult, Mutex, MutexGuard,
|
||||
};
|
||||
|
||||
tokio::task_local! {
|
||||
pub(crate) static NUM_LOCKED: AtomicUsize;
|
||||
}
|
||||
|
||||
pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: std::future::IntoFuture,
|
||||
{
|
||||
NUM_LOCKED
|
||||
.scope(AtomicUsize::new(0), async move {
|
||||
let output = f().await;
|
||||
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
|
||||
(num, output)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) struct AxumMutex<T>(Mutex<T>);
|
||||
|
||||
impl<T> AxumMutex<T> {
|
||||
pub(crate) fn new(value: T) -> Self {
|
||||
Self(Mutex::new(value))
|
||||
}
|
||||
|
||||
pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
|
||||
self.0.get_mut()
|
||||
}
|
||||
|
||||
pub(crate) fn into_inner(self) -> LockResult<T> {
|
||||
self.0.into_inner()
|
||||
}
|
||||
|
||||
pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
|
||||
_ = NUM_LOCKED.try_with(|num| {
|
||||
num.fetch_add(1, Ordering::SeqCst);
|
||||
});
|
||||
self.0.lock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue