diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 61e7e0f6..7499e1fc 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -103,6 +103,31 @@ pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_pa pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback"; +macro_rules! map_inner { + ( $self_:ident, $inner:pat_param => $expr:expr) => { + #[allow(redundant_semicolons)] + { + let $inner = $self_.into_inner(); + Router { + inner: Arc::new($expr), + } + } + }; +} + +macro_rules! tap_inner { + ( $self_:ident, mut $inner:ident => { $($stmt:stmt)* } ) => { + #[allow(redundant_semicolons)] + { + let mut $inner = $self_.into_inner(); + $($stmt)* + Router { + inner: Arc::new($inner), + } + } + }; +} + impl Router where S: Clone + Send + Sync + 'static, @@ -122,26 +147,6 @@ where } } - fn map_inner(self, f: F) -> Router - where - F: FnOnce(RouterInner) -> RouterInner, - { - Router { - inner: Arc::new(f(self.into_inner())), - } - } - - fn tap_inner_mut(self, f: F) -> Self - where - F: FnOnce(&mut RouterInner), - { - let mut inner = self.into_inner(); - f(&mut inner); - Router { - inner: Arc::new(inner), - } - } - fn into_inner(self) -> RouterInner { match Arc::try_unwrap(self.inner) { Ok(inner) => inner, @@ -157,7 +162,7 @@ where #[doc = include_str!("../docs/routing/route.md")] #[track_caller] pub fn route(self, path: &str, method_router: MethodRouter) -> Self { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.route(path, method_router)); }) } @@ -179,7 +184,7 @@ where Err(service) => service, }; - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.route_service(path, service)); }) } @@ -198,7 +203,7 @@ where catch_all_fallback: _, } = router.into_inner(); - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest(path, path_router)); if !default_fallback { @@ -215,7 +220,7 @@ where T::Response: IntoResponse, T::Future: Send + 'static, { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { panic_on_err!(this.path_router.nest_service(path, service)); }) } @@ -237,7 +242,7 @@ where catch_all_fallback, } = other.into_inner(); - self.map_inner(|mut this| { + map_inner!(self, mut this => { panic_on_err!(this.path_router.merge(path_router)); match (this.default_fallback, default_fallback) { @@ -281,7 +286,7 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.layer(layer.clone()), fallback_router: this.fallback_router.layer(layer.clone()), default_fallback: this.default_fallback, @@ -299,7 +304,7 @@ where >::Error: Into + 'static, >::Future: Send + 'static, { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.route_layer(layer), fallback_router: this.fallback_router, default_fallback: this.default_fallback, @@ -319,7 +324,7 @@ where H: Handler, T: 'static, { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.catch_all_fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); }) @@ -336,14 +341,14 @@ where T::Future: Send + 'static, { let route = Route::new(service); - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.catch_all_fallback = Fallback::Service(route.clone()); }) .fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(self, endpoint: Endpoint) -> Self { - self.tap_inner_mut(|this| { + tap_inner!(self, mut this => { this.fallback_router.set_fallback(endpoint); this.default_fallback = false; }) @@ -351,7 +356,7 @@ where #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - self.map_inner(|this| RouterInner { + map_inner!(self, this => RouterInner { path_router: this.path_router.with_state(state.clone()), fallback_router: this.fallback_router.with_state(state.clone()), default_fallback: this.default_fallback,