Replace Router::{map_inner, tap_inner_mut} by macros (#2954)

This commit is contained in:
Jonas Platte 2024-10-04 17:00:50 +00:00 committed by GitHub
parent 31a87f8b2b
commit 20a0624795
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -103,6 +103,31 @@ pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/{*__private__axum_nest_tail_p
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
pub(crate) const FALLBACK_PARAM_PATH: &str = "/{*__private__axum_fallback}"; pub(crate) const FALLBACK_PARAM_PATH: &str = "/{*__private__axum_fallback}";
macro_rules! map_inner {
( $self_:ident, $inner:pat_param => $expr:expr) => {
#[allow(redundant_semicolons)]
{
let $inner = $self_.into_inner();
Router {
inner: Arc::new($expr),
}
}
};
}
macro_rules! tap_inner {
( $self_:ident, mut $inner:ident => { $($stmt:stmt)* } ) => {
#[allow(redundant_semicolons)]
{
let mut $inner = $self_.into_inner();
$($stmt)*
Router {
inner: Arc::new($inner),
}
}
};
}
impl<S> Router<S> impl<S> Router<S>
where where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
@ -122,26 +147,6 @@ where
} }
} }
fn map_inner<F, S2>(self, f: F) -> Router<S2>
where
F: FnOnce(RouterInner<S>) -> RouterInner<S2>,
{
Router {
inner: Arc::new(f(self.into_inner())),
}
}
fn tap_inner_mut<F>(self, f: F) -> Self
where
F: FnOnce(&mut RouterInner<S>),
{
let mut inner = self.into_inner();
f(&mut inner);
Router {
inner: Arc::new(inner),
}
}
fn into_inner(self) -> RouterInner<S> { fn into_inner(self) -> RouterInner<S> {
match Arc::try_unwrap(self.inner) { match Arc::try_unwrap(self.inner) {
Ok(inner) => inner, Ok(inner) => inner,
@ -156,7 +161,7 @@ where
#[doc = include_str!("../docs/routing/without_v07_checks.md")] #[doc = include_str!("../docs/routing/without_v07_checks.md")]
pub fn without_v07_checks(self) -> Self { pub fn without_v07_checks(self) -> Self {
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
this.path_router.without_v07_checks(); this.path_router.without_v07_checks();
}) })
} }
@ -164,7 +169,7 @@ where
#[doc = include_str!("../docs/routing/route.md")] #[doc = include_str!("../docs/routing/route.md")]
#[track_caller] #[track_caller]
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self { pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self {
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
panic_on_err!(this.path_router.route(path, method_router)); panic_on_err!(this.path_router.route(path, method_router));
}) })
} }
@ -186,7 +191,7 @@ where
Err(service) => service, Err(service) => service,
}; };
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
panic_on_err!(this.path_router.route_service(path, service)); panic_on_err!(this.path_router.route_service(path, service));
}) })
} }
@ -205,7 +210,7 @@ where
catch_all_fallback: _, catch_all_fallback: _,
} = router.into_inner(); } = router.into_inner();
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest(path, path_router)); panic_on_err!(this.path_router.nest(path, path_router));
if !default_fallback { if !default_fallback {
@ -222,7 +227,7 @@ where
T::Response: IntoResponse, T::Response: IntoResponse,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest_service(path, service)); panic_on_err!(this.path_router.nest_service(path, service));
}) })
} }
@ -244,7 +249,7 @@ where
catch_all_fallback, catch_all_fallback,
} = other.into_inner(); } = other.into_inner();
self.map_inner(|mut this| { map_inner!(self, mut this => {
panic_on_err!(this.path_router.merge(path_router)); panic_on_err!(this.path_router.merge(path_router));
match (this.default_fallback, default_fallback) { match (this.default_fallback, default_fallback) {
@ -288,7 +293,7 @@ where
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static, <L::Service as Service<Request>>::Future: Send + 'static,
{ {
self.map_inner(|this| RouterInner { map_inner!(self, this => RouterInner {
path_router: this.path_router.layer(layer.clone()), path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()), fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback, default_fallback: this.default_fallback,
@ -306,7 +311,7 @@ where
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static, <L::Service as Service<Request>>::Future: Send + 'static,
{ {
self.map_inner(|this| RouterInner { map_inner!(self, this => RouterInner {
path_router: this.path_router.route_layer(layer), path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router, fallback_router: this.fallback_router,
default_fallback: this.default_fallback, default_fallback: this.default_fallback,
@ -326,7 +331,7 @@ where
H: Handler<T, S>, H: Handler<T, S>,
T: 'static, T: 'static,
{ {
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
this.catch_all_fallback = this.catch_all_fallback =
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
}) })
@ -343,14 +348,14 @@ where
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
let route = Route::new(service); let route = Route::new(service);
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
this.catch_all_fallback = Fallback::Service(route.clone()); this.catch_all_fallback = Fallback::Service(route.clone());
}) })
.fallback_endpoint(Endpoint::Route(route)) .fallback_endpoint(Endpoint::Route(route))
} }
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self { fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
self.tap_inner_mut(|this| { tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint); this.fallback_router.set_fallback(endpoint);
this.default_fallback = false; this.default_fallback = false;
}) })
@ -358,7 +363,7 @@ where
#[doc = include_str!("../docs/routing/with_state.md")] #[doc = include_str!("../docs/routing/with_state.md")]
pub fn with_state<S2>(self, state: S) -> Router<S2> { pub fn with_state<S2>(self, state: S) -> Router<S2> {
self.map_inner(|this| RouterInner { map_inner!(self, this => RouterInner {
path_router: this.path_router.with_state(state.clone()), path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()), fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback, default_fallback: this.default_fallback,