mirror of
https://github.com/tokio-rs/axum.git
synced 2024-10-23 17:36:39 +02:00
Internally Arc
Router
, without breaking changes (#2483)
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
parent
d3112a40d5
commit
45116730c6
10 changed files with 277 additions and 125 deletions
|
@ -8,9 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
# Unreleased
|
# Unreleased
|
||||||
|
|
||||||
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
|
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
|
||||||
|
- **fixed:** Fix performance regression present since axum 0.7.0 ([#2483])
|
||||||
- **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484])
|
- **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484])
|
||||||
- **added:** Re-export `axum_core::body::BodyDataStream` from axum
|
- **added:** Re-export `axum_core::body::BodyDataStream` from axum
|
||||||
|
|
||||||
|
[#2201]: https://github.com/tokio-rs/axum/pull/2201
|
||||||
|
[#2483]: https://github.com/tokio-rs/axum/pull/2483
|
||||||
[#2201]: https://github.com/tokio-rs/axum/pull/2201
|
[#2201]: https://github.com/tokio-rs/axum/pull/2201
|
||||||
[#2484]: https://github.com/tokio-rs/axum/pull/2484
|
[#2484]: https://github.com/tokio-rs/axum/pull/2484
|
||||||
|
|
||||||
|
|
3
axum/clippy.toml
Normal file
3
axum/clippy.toml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
disallowed-types = [
|
||||||
|
{ path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" },
|
||||||
|
]
|
|
@ -1,6 +1,7 @@
|
||||||
use std::{convert::Infallible, fmt, sync::Mutex};
|
use std::{convert::Infallible, fmt};
|
||||||
|
|
||||||
use crate::extract::Request;
|
use crate::extract::Request;
|
||||||
|
use crate::util::AxumMutex;
|
||||||
use tower::Service;
|
use tower::Service;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -9,7 +10,7 @@ use crate::{
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) struct BoxedIntoRoute<S, E>(Mutex<Box<dyn ErasedIntoRoute<S, E>>>);
|
pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
|
||||||
|
|
||||||
impl<S> BoxedIntoRoute<S, Infallible>
|
impl<S> BoxedIntoRoute<S, Infallible>
|
||||||
where
|
where
|
||||||
|
@ -20,7 +21,7 @@ where
|
||||||
H: Handler<T, S>,
|
H: Handler<T, S>,
|
||||||
T: 'static,
|
T: 'static,
|
||||||
{
|
{
|
||||||
Self(Mutex::new(Box::new(MakeErasedHandler {
|
Self(AxumMutex::new(Box::new(MakeErasedHandler {
|
||||||
handler,
|
handler,
|
||||||
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
|
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
|
||||||
})))
|
})))
|
||||||
|
@ -35,7 +36,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
|
||||||
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
|
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + 'static,
|
||||||
E2: 'static,
|
E2: 'static,
|
||||||
{
|
{
|
||||||
BoxedIntoRoute(Mutex::new(Box::new(Map {
|
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
|
||||||
inner: self.0.into_inner().unwrap(),
|
inner: self.0.into_inner().unwrap(),
|
||||||
layer: Box::new(f),
|
layer: Box::new(f),
|
||||||
})))
|
})))
|
||||||
|
@ -48,7 +49,7 @@ impl<S, E> BoxedIntoRoute<S, E> {
|
||||||
|
|
||||||
impl<S, E> Clone for BoxedIntoRoute<S, E> {
|
impl<S, E> Clone for BoxedIntoRoute<S, E> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self(Mutex::new(self.0.lock().unwrap().clone_box()))
|
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,7 +119,7 @@ where
|
||||||
(self.into_route)(self.router, state)
|
(self.into_route)(self.router, state)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_with_state(mut self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
|
fn call_with_state(self: Box<Self>, request: Request, state: S) -> RouteFuture<Infallible> {
|
||||||
self.router.call_with_state(request, state)
|
self.router.call_with_state(request, state)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1022,7 +1022,7 @@ where
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
|
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
|
||||||
macro_rules! call {
|
macro_rules! call {
|
||||||
(
|
(
|
||||||
$req:expr,
|
$req:expr,
|
||||||
|
@ -1034,12 +1034,12 @@ where
|
||||||
match $svc {
|
match $svc {
|
||||||
MethodEndpoint::None => {}
|
MethodEndpoint::None => {}
|
||||||
MethodEndpoint::Route(route) => {
|
MethodEndpoint::Route(route) => {
|
||||||
return RouteFuture::from_future(route.oneshot_inner($req))
|
return RouteFuture::from_future(route.clone().oneshot_inner($req))
|
||||||
.strip_body($method == Method::HEAD);
|
.strip_body($method == Method::HEAD);
|
||||||
}
|
}
|
||||||
MethodEndpoint::BoxedHandler(handler) => {
|
MethodEndpoint::BoxedHandler(handler) => {
|
||||||
let mut route = handler.clone().into_route(state);
|
let route = handler.clone().into_route(state);
|
||||||
return RouteFuture::from_future(route.oneshot_inner($req))
|
return RouteFuture::from_future(route.clone().oneshot_inner($req))
|
||||||
.strip_body($method == Method::HEAD);
|
.strip_body($method == Method::HEAD);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1073,7 +1073,7 @@ where
|
||||||
call!(req, method, DELETE, delete);
|
call!(req, method, DELETE, delete);
|
||||||
call!(req, method, TRACE, trace);
|
call!(req, method, TRACE, trace);
|
||||||
|
|
||||||
let future = fallback.call_with_state(req, state);
|
let future = fallback.clone().call_with_state(req, state);
|
||||||
|
|
||||||
match allow_header {
|
match allow_header {
|
||||||
AllowHeader::None => future.allow_header(Bytes::new()),
|
AllowHeader::None => future.allow_header(Bytes::new()),
|
||||||
|
@ -1219,7 +1219,7 @@ where
|
||||||
{
|
{
|
||||||
type Future = InfallibleRouteFuture;
|
type Future = InfallibleRouteFuture;
|
||||||
|
|
||||||
fn call(mut self, req: Request, state: S) -> Self::Future {
|
fn call(self, req: Request, state: S) -> Self::Future {
|
||||||
InfallibleRouteFuture::new(self.call_with_state(req, state))
|
InfallibleRouteFuture::new(self.call_with_state(req, state))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
fmt,
|
fmt,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tower_layer::Layer;
|
use tower_layer::Layer;
|
||||||
|
@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32);
|
||||||
/// The router type for composing handlers and services.
|
/// The router type for composing handlers and services.
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub struct Router<S = ()> {
|
pub struct Router<S = ()> {
|
||||||
path_router: PathRouter<S, false>,
|
inner: Arc<RouterInner<S>>,
|
||||||
fallback_router: PathRouter<S, true>,
|
|
||||||
default_fallback: bool,
|
|
||||||
catch_all_fallback: Fallback<S>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Clone for Router<S> {
|
impl<S> Clone for Router<S> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
path_router: self.path_router.clone(),
|
inner: Arc::clone(&self.inner),
|
||||||
fallback_router: self.fallback_router.clone(),
|
|
||||||
default_fallback: self.default_fallback,
|
|
||||||
catch_all_fallback: self.catch_all_fallback.clone(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct RouterInner<S> {
|
||||||
|
path_router: PathRouter<S, false>,
|
||||||
|
fallback_router: PathRouter<S, true>,
|
||||||
|
default_fallback: bool,
|
||||||
|
catch_all_fallback: Fallback<S>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<S> Default for Router<S>
|
impl<S> Default for Router<S>
|
||||||
where
|
where
|
||||||
S: Clone + Send + Sync + 'static,
|
S: Clone + Send + Sync + 'static,
|
||||||
|
@ -88,10 +90,10 @@ where
|
||||||
impl<S> fmt::Debug for Router<S> {
|
impl<S> fmt::Debug for Router<S> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Router")
|
f.debug_struct("Router")
|
||||||
.field("path_router", &self.path_router)
|
.field("path_router", &self.inner.path_router)
|
||||||
.field("fallback_router", &self.fallback_router)
|
.field("fallback_router", &self.inner.fallback_router)
|
||||||
.field("default_fallback", &self.default_fallback)
|
.field("default_fallback", &self.inner.default_fallback)
|
||||||
.field("catch_all_fallback", &self.catch_all_fallback)
|
.field("catch_all_fallback", &self.inner.catch_all_fallback)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -111,22 +113,57 @@ where
|
||||||
/// all requests.
|
/// all requests.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
inner: Arc::new(RouterInner {
|
||||||
path_router: Default::default(),
|
path_router: Default::default(),
|
||||||
fallback_router: PathRouter::new_fallback(),
|
fallback_router: PathRouter::new_fallback(),
|
||||||
default_fallback: true,
|
default_fallback: true,
|
||||||
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
|
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_inner<F, S2>(self, f: F) -> Router<S2>
|
||||||
|
where
|
||||||
|
F: FnOnce(RouterInner<S>) -> RouterInner<S2>,
|
||||||
|
{
|
||||||
|
Router {
|
||||||
|
inner: Arc::new(f(self.into_inner())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tap_inner_mut<F>(self, f: F) -> Self
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut RouterInner<S>),
|
||||||
|
{
|
||||||
|
let mut inner = self.into_inner();
|
||||||
|
f(&mut inner);
|
||||||
|
Router {
|
||||||
|
inner: Arc::new(inner),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_inner(self) -> RouterInner<S> {
|
||||||
|
match Arc::try_unwrap(self.inner) {
|
||||||
|
Ok(inner) => inner,
|
||||||
|
Err(arc) => RouterInner {
|
||||||
|
path_router: arc.path_router.clone(),
|
||||||
|
fallback_router: arc.fallback_router.clone(),
|
||||||
|
default_fallback: arc.default_fallback,
|
||||||
|
catch_all_fallback: arc.catch_all_fallback.clone(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/route.md")]
|
#[doc = include_str!("../docs/routing/route.md")]
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn route(mut self, path: &str, method_router: MethodRouter<S>) -> Self {
|
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
|
||||||
panic_on_err!(self.path_router.route(path, method_router));
|
self.tap_inner_mut(|this| {
|
||||||
self
|
panic_on_err!(this.path_router.route(path, method_router));
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/route_service.md")]
|
#[doc = include_str!("../docs/routing/route_service.md")]
|
||||||
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
|
pub fn route_service<T>(self, path: &str, service: T) -> Self
|
||||||
where
|
where
|
||||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||||
T::Response: IntoResponse,
|
T::Response: IntoResponse,
|
||||||
|
@ -142,14 +179,15 @@ where
|
||||||
Err(service) => service,
|
Err(service) => service,
|
||||||
};
|
};
|
||||||
|
|
||||||
panic_on_err!(self.path_router.route_service(path, service));
|
self.tap_inner_mut(|this| {
|
||||||
self
|
panic_on_err!(this.path_router.route_service(path, service));
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/nest.md")]
|
#[doc = include_str!("../docs/routing/nest.md")]
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn nest(mut self, path: &str, router: Router<S>) -> Self {
|
pub fn nest(self, path: &str, router: Router<S>) -> Self {
|
||||||
let Router {
|
let RouterInner {
|
||||||
path_router,
|
path_router,
|
||||||
fallback_router,
|
fallback_router,
|
||||||
default_fallback,
|
default_fallback,
|
||||||
|
@ -157,63 +195,66 @@ where
|
||||||
// requests with an empty path. If we were to inherit the catch-all fallback
|
// requests with an empty path. If we were to inherit the catch-all fallback
|
||||||
// it would end up matching `/{path}/*` which doesn't match empty paths.
|
// it would end up matching `/{path}/*` which doesn't match empty paths.
|
||||||
catch_all_fallback: _,
|
catch_all_fallback: _,
|
||||||
} = router;
|
} = router.into_inner();
|
||||||
|
|
||||||
panic_on_err!(self.path_router.nest(path, path_router));
|
self.tap_inner_mut(|this| {
|
||||||
|
panic_on_err!(this.path_router.nest(path, path_router));
|
||||||
|
|
||||||
if !default_fallback {
|
if !default_fallback {
|
||||||
panic_on_err!(self.fallback_router.nest(path, fallback_router));
|
panic_on_err!(this.fallback_router.nest(path, fallback_router));
|
||||||
}
|
}
|
||||||
|
})
|
||||||
self
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
|
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
|
pub fn nest_service<T>(self, path: &str, service: T) -> Self
|
||||||
where
|
where
|
||||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||||
T::Response: IntoResponse,
|
T::Response: IntoResponse,
|
||||||
T::Future: Send + 'static,
|
T::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
panic_on_err!(self.path_router.nest_service(path, service));
|
self.tap_inner_mut(|this| {
|
||||||
self
|
panic_on_err!(this.path_router.nest_service(path, service));
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/merge.md")]
|
#[doc = include_str!("../docs/routing/merge.md")]
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn merge<R>(mut self, other: R) -> Self
|
pub fn merge<R>(self, other: R) -> Self
|
||||||
where
|
where
|
||||||
R: Into<Router<S>>,
|
R: Into<Router<S>>,
|
||||||
{
|
{
|
||||||
const PANIC_MSG: &str =
|
const PANIC_MSG: &str =
|
||||||
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
|
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
|
||||||
|
|
||||||
let Router {
|
let other: Router<S> = other.into();
|
||||||
|
let RouterInner {
|
||||||
path_router,
|
path_router,
|
||||||
fallback_router: mut other_fallback,
|
fallback_router: mut other_fallback,
|
||||||
default_fallback,
|
default_fallback,
|
||||||
catch_all_fallback,
|
catch_all_fallback,
|
||||||
} = other.into();
|
} = other.into_inner();
|
||||||
|
|
||||||
panic_on_err!(self.path_router.merge(path_router));
|
self.map_inner(|mut this| {
|
||||||
|
panic_on_err!(this.path_router.merge(path_router));
|
||||||
|
|
||||||
match (self.default_fallback, default_fallback) {
|
match (this.default_fallback, default_fallback) {
|
||||||
// both have the default fallback
|
// both have the default fallback
|
||||||
// use the one from other
|
// use the one from other
|
||||||
(true, true) => {
|
(true, true) => {
|
||||||
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||||
}
|
}
|
||||||
// self has default fallback, other has a custom fallback
|
// this has default fallback, other has a custom fallback
|
||||||
(true, false) => {
|
(true, false) => {
|
||||||
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
|
||||||
self.default_fallback = false;
|
this.default_fallback = false;
|
||||||
}
|
}
|
||||||
// self has a custom fallback, other has a default
|
// this has a custom fallback, other has a default
|
||||||
(false, true) => {
|
(false, true) => {
|
||||||
let fallback_router = std::mem::take(&mut self.fallback_router);
|
let fallback_router = std::mem::take(&mut this.fallback_router);
|
||||||
other_fallback.merge(fallback_router).expect(PANIC_MSG);
|
other_fallback.merge(fallback_router).expect(PANIC_MSG);
|
||||||
self.fallback_router = other_fallback;
|
this.fallback_router = other_fallback;
|
||||||
}
|
}
|
||||||
// both have a custom fallback, not allowed
|
// both have a custom fallback, not allowed
|
||||||
(false, false) => {
|
(false, false) => {
|
||||||
|
@ -221,12 +262,13 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.catch_all_fallback = self
|
this.catch_all_fallback = this
|
||||||
.catch_all_fallback
|
.catch_all_fallback
|
||||||
.merge(catch_all_fallback)
|
.merge(catch_all_fallback)
|
||||||
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
|
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
|
||||||
|
|
||||||
self
|
this
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/layer.md")]
|
#[doc = include_str!("../docs/routing/layer.md")]
|
||||||
|
@ -238,12 +280,12 @@ where
|
||||||
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
||||||
<L::Service as Service<Request>>::Future: Send + 'static,
|
<L::Service as Service<Request>>::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
Router {
|
self.map_inner(|this| RouterInner {
|
||||||
path_router: self.path_router.layer(layer.clone()),
|
path_router: this.path_router.layer(layer.clone()),
|
||||||
fallback_router: self.fallback_router.layer(layer.clone()),
|
fallback_router: this.fallback_router.layer(layer.clone()),
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: this.default_fallback,
|
||||||
catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
|
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/route_layer.md")]
|
#[doc = include_str!("../docs/routing/route_layer.md")]
|
||||||
|
@ -256,68 +298,76 @@ where
|
||||||
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
|
||||||
<L::Service as Service<Request>>::Future: Send + 'static,
|
<L::Service as Service<Request>>::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
Router {
|
self.map_inner(|this| RouterInner {
|
||||||
path_router: self.path_router.route_layer(layer),
|
path_router: this.path_router.route_layer(layer),
|
||||||
fallback_router: self.fallback_router,
|
fallback_router: this.fallback_router,
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: this.default_fallback,
|
||||||
catch_all_fallback: self.catch_all_fallback,
|
catch_all_fallback: this.catch_all_fallback,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||||
pub fn fallback<H, T>(mut self, handler: H) -> Self
|
pub fn fallback<H, T>(self, handler: H) -> Self
|
||||||
where
|
where
|
||||||
H: Handler<T, S>,
|
H: Handler<T, S>,
|
||||||
T: 'static,
|
T: 'static,
|
||||||
{
|
{
|
||||||
self.catch_all_fallback =
|
self.tap_inner_mut(|this| {
|
||||||
|
this.catch_all_fallback =
|
||||||
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
|
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
|
||||||
self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
|
})
|
||||||
|
.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a fallback [`Service`] to the router.
|
/// Add a fallback [`Service`] to the router.
|
||||||
///
|
///
|
||||||
/// See [`Router::fallback`] for more details.
|
/// See [`Router::fallback`] for more details.
|
||||||
pub fn fallback_service<T>(mut self, service: T) -> Self
|
pub fn fallback_service<T>(self, service: T) -> Self
|
||||||
where
|
where
|
||||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||||
T::Response: IntoResponse,
|
T::Response: IntoResponse,
|
||||||
T::Future: Send + 'static,
|
T::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
let route = Route::new(service);
|
let route = Route::new(service);
|
||||||
self.catch_all_fallback = Fallback::Service(route.clone());
|
self.tap_inner_mut(|this| {
|
||||||
self.fallback_endpoint(Endpoint::Route(route))
|
this.catch_all_fallback = Fallback::Service(route.clone());
|
||||||
|
})
|
||||||
|
.fallback_endpoint(Endpoint::Route(route))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fallback_endpoint(mut self, endpoint: Endpoint<S>) -> Self {
|
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
|
||||||
self.fallback_router.set_fallback(endpoint);
|
self.tap_inner_mut(|this| {
|
||||||
self.default_fallback = false;
|
this.fallback_router.set_fallback(endpoint);
|
||||||
self
|
this.default_fallback = false;
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/with_state.md")]
|
#[doc = include_str!("../docs/routing/with_state.md")]
|
||||||
pub fn with_state<S2>(self, state: S) -> Router<S2> {
|
pub fn with_state<S2>(self, state: S) -> Router<S2> {
|
||||||
Router {
|
self.map_inner(|this| RouterInner {
|
||||||
path_router: self.path_router.with_state(state.clone()),
|
path_router: this.path_router.with_state(state.clone()),
|
||||||
fallback_router: self.fallback_router.with_state(state.clone()),
|
fallback_router: this.fallback_router.with_state(state.clone()),
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: this.default_fallback,
|
||||||
catch_all_fallback: self.catch_all_fallback.with_state(state),
|
catch_all_fallback: this.catch_all_fallback.with_state(state),
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<Infallible> {
|
pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<Infallible> {
|
||||||
let (req, state) = match self.path_router.call_with_state(req, state) {
|
let (req, state) = match self.inner.path_router.call_with_state(req, state) {
|
||||||
Ok(future) => return future,
|
Ok(future) => return future,
|
||||||
Err((req, state)) => (req, state),
|
Err((req, state)) => (req, state),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (req, state) = match self.fallback_router.call_with_state(req, state) {
|
let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
|
||||||
Ok(future) => return future,
|
Ok(future) => return future,
|
||||||
Err((req, state)) => (req, state),
|
Err((req, state)) => (req, state),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.catch_all_fallback.call_with_state(req, state)
|
self.inner
|
||||||
|
.catch_all_fallback
|
||||||
|
.clone()
|
||||||
|
.call_with_state(req, state)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type
|
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type
|
||||||
|
|
|
@ -316,7 +316,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn call_with_state(
|
pub(super) fn call_with_state(
|
||||||
&mut self,
|
&self,
|
||||||
mut req: Request,
|
mut req: Request,
|
||||||
state: S,
|
state: S,
|
||||||
) -> Result<RouteFuture<Infallible>, (Request, S)> {
|
) -> Result<RouteFuture<Infallible>, (Request, S)> {
|
||||||
|
@ -349,7 +349,7 @@ where
|
||||||
|
|
||||||
let endpoint = self
|
let endpoint = self
|
||||||
.routes
|
.routes
|
||||||
.get_mut(&id)
|
.get(&id)
|
||||||
.expect("no route for id. This is a bug in axum. Please file an issue");
|
.expect("no route for id. This is a bug in axum. Please file an issue");
|
||||||
|
|
||||||
match endpoint {
|
match endpoint {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
body::{Body, HttpBody},
|
body::{Body, HttpBody},
|
||||||
response::Response,
|
response::Response,
|
||||||
|
util::AxumMutex,
|
||||||
};
|
};
|
||||||
use axum_core::{extract::Request, response::IntoResponse};
|
use axum_core::{extract::Request, response::IntoResponse};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
@ -14,7 +15,6 @@ use std::{
|
||||||
fmt,
|
fmt,
|
||||||
future::Future,
|
future::Future,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Mutex,
|
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
use tower::{
|
use tower::{
|
||||||
|
@ -28,7 +28,7 @@ use tower_service::Service;
|
||||||
///
|
///
|
||||||
/// You normally shouldn't need to care about this type. It's used in
|
/// You normally shouldn't need to care about this type. It's used in
|
||||||
/// [`Router::layer`](super::Router::layer).
|
/// [`Router::layer`](super::Router::layer).
|
||||||
pub struct Route<E = Infallible>(Mutex<BoxCloneService<Request, Response, E>>);
|
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
|
||||||
|
|
||||||
impl<E> Route<E> {
|
impl<E> Route<E> {
|
||||||
pub(crate) fn new<T>(svc: T) -> Self
|
pub(crate) fn new<T>(svc: T) -> Self
|
||||||
|
@ -37,7 +37,7 @@ impl<E> Route<E> {
|
||||||
T::Response: IntoResponse + 'static,
|
T::Response: IntoResponse + 'static,
|
||||||
T::Future: Send + 'static,
|
T::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
Self(Mutex::new(BoxCloneService::new(
|
Self(AxumMutex::new(BoxCloneService::new(
|
||||||
svc.map_response(IntoResponse::into_response),
|
svc.map_response(IntoResponse::into_response),
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
@ -70,8 +70,9 @@ impl<E> Route<E> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E> Clone for Route<E> {
|
impl<E> Clone for Route<E> {
|
||||||
|
#[track_caller]
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self(Mutex::new(self.0.lock().unwrap().clone()))
|
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ use crate::{
|
||||||
tracing_helpers::{capture_tracing, TracingEvent},
|
tracing_helpers::{capture_tracing, TracingEvent},
|
||||||
*,
|
*,
|
||||||
},
|
},
|
||||||
|
util::mutex_num_locked,
|
||||||
BoxError, Extension, Json, Router, ServiceExt,
|
BoxError, Extension, Json, Router, ServiceExt,
|
||||||
};
|
};
|
||||||
use axum_core::extract::Request;
|
use axum_core::extract::Request;
|
||||||
|
@ -951,7 +952,7 @@ async fn state_isnt_cloned_too_much() {
|
||||||
|
|
||||||
client.get("/").await;
|
client.get("/").await;
|
||||||
|
|
||||||
assert_eq!(COUNT.load(Ordering::SeqCst), 5);
|
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[crate::test]
|
#[crate::test]
|
||||||
|
@ -1066,3 +1067,35 @@ async fn impl_handler_for_into_response() {
|
||||||
assert_eq!(res.status(), StatusCode::CREATED);
|
assert_eq!(res.status(), StatusCode::CREATED);
|
||||||
assert_eq!(res.text().await, "thing created");
|
assert_eq!(res.text().await, "thing created");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn locks_mutex_very_little() {
|
||||||
|
let (num, app) = mutex_num_locked(|| async {
|
||||||
|
Router::new()
|
||||||
|
.route("/a", get(|| async {}))
|
||||||
|
.route("/b", get(|| async {}))
|
||||||
|
.route("/c", get(|| async {}))
|
||||||
|
.with_state::<()>(())
|
||||||
|
.into_service::<Body>()
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
// once for `Router::new` for setting the default fallback and 3 times, once per route
|
||||||
|
assert_eq!(num, 4);
|
||||||
|
|
||||||
|
for path in ["/a", "/b", "/c"] {
|
||||||
|
// calling the router should only lock the mutex once
|
||||||
|
let (num, _res) = mutex_num_locked(|| async {
|
||||||
|
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
|
||||||
|
// connection and `mutex_num_locked` uses a task local to keep track of the number of
|
||||||
|
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
|
||||||
|
//
|
||||||
|
// So instead `call` the service directly without spawning new tasks.
|
||||||
|
app.clone()
|
||||||
|
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
assert_eq!(num, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
use std::{
|
use crate::util::AxumMutex;
|
||||||
future::Future,
|
use std::{future::Future, io, sync::Arc};
|
||||||
io,
|
|
||||||
sync::{Arc, Mutex},
|
|
||||||
};
|
|
||||||
|
|
||||||
use serde::{de::DeserializeOwned, Deserialize};
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
|
@ -50,12 +47,12 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestMakeWriter {
|
struct TestMakeWriter {
|
||||||
write: Arc<Mutex<Option<Vec<u8>>>>,
|
write: Arc<AxumMutex<Option<Vec<u8>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestMakeWriter {
|
impl TestMakeWriter {
|
||||||
fn new() -> (Self, Handle) {
|
fn new() -> (Self, Handle) {
|
||||||
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));
|
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
|
||||||
|
|
||||||
(
|
(
|
||||||
Self {
|
Self {
|
||||||
|
@ -97,7 +94,7 @@ impl<'a> io::Write for Writer<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Handle {
|
struct Handle {
|
||||||
write: Arc<Mutex<Option<Vec<u8>>>>,
|
write: Arc<AxumMutex<Option<Vec<u8>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Handle {
|
impl Handle {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{ops::Deref, sync::Arc};
|
use std::{ops::Deref, sync::Arc};
|
||||||
|
|
||||||
|
pub(crate) use self::mutex::*;
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||||
pub(crate) struct PercentDecodedStr(Arc<str>);
|
pub(crate) struct PercentDecodedStr(Arc<str>);
|
||||||
|
|
||||||
|
@ -59,3 +61,65 @@ fn test_try_downcast() {
|
||||||
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
|
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
|
||||||
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
|
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
|
||||||
|
// times it's been locked on the current task. That way we can write a test to ensure we don't
|
||||||
|
// accidentally introduce more locking.
|
||||||
|
//
|
||||||
|
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
|
||||||
|
#[cfg(not(test))]
|
||||||
|
mod mutex {
|
||||||
|
#[allow(clippy::disallowed_types)]
|
||||||
|
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod mutex {
|
||||||
|
#![allow(clippy::disallowed_types)]
|
||||||
|
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
LockResult, Mutex, MutexGuard,
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::task_local! {
|
||||||
|
pub(crate) static NUM_LOCKED: AtomicUsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
|
||||||
|
where
|
||||||
|
F: FnOnce() -> Fut,
|
||||||
|
Fut: std::future::IntoFuture,
|
||||||
|
{
|
||||||
|
NUM_LOCKED
|
||||||
|
.scope(AtomicUsize::new(0), async move {
|
||||||
|
let output = f().await;
|
||||||
|
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
|
||||||
|
(num, output)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct AxumMutex<T>(Mutex<T>);
|
||||||
|
|
||||||
|
impl<T> AxumMutex<T> {
|
||||||
|
pub(crate) fn new(value: T) -> Self {
|
||||||
|
Self(Mutex::new(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
|
||||||
|
self.0.get_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn into_inner(self) -> LockResult<T> {
|
||||||
|
self.0.into_inner()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
|
||||||
|
_ = NUM_LOCKED.try_with(|num| {
|
||||||
|
num.fetch_add(1, Ordering::SeqCst);
|
||||||
|
});
|
||||||
|
self.0.lock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue