From b816ac7cdf14d8dade5517ee6b49989e1db62074 Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Sat, 19 Nov 2022 22:51:55 +0100 Subject: [PATCH] Add `RouterService::{layer, route_layer}` (#1550) * Add `RouterService::{layer, route_layer}` Figure we might as well have these. * changelog --- axum/CHANGELOG.md | 4 ++- axum/src/routing/mod.rs | 21 ++++++++++++++- axum/src/routing/service.rs | 54 +++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 6f43c760..bbf51ef8 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `RouterService::{layer, route_layer}` ([#1550]) + +[#1550]: https://github.com/tokio-rs/axum/pull/1550 # 0.6.0-rc.5 (18. November, 2022) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 5eaf1c0d..2544ac41 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -311,13 +311,14 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody> + pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, + NewReqBody: 'static, { let routes = self .routes @@ -566,6 +567,24 @@ pub(crate) enum FallbackRoute<B, E = Infallible> { Service(Route<B, E>), } +impl<B, E> FallbackRoute<B, E> { + fn layer<L, NewReqBody, NewError>(self, layer: L) -> FallbackRoute<NewReqBody, NewError> + where + L: Layer<Route<B, E>> + Clone + Send + 'static, + L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, + <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, + <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static, + <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, + NewReqBody: 'static, + NewError: 'static, + { + match self { + FallbackRoute::Default(route) => FallbackRoute::Default(route.layer(layer)), + FallbackRoute::Service(route) => FallbackRoute::Service(route.layer(layer)), + } + } +} + impl<B, E> fmt::Debug for FallbackRoute<B, E> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index 141bd442..cbd655d8 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -5,6 +5,7 @@ use crate::{ body::{Body, HttpBody}, response::Response, }; +use axum_core::response::IntoResponse; use http::Request; use matchit::MatchError; use std::{ @@ -15,6 +16,7 @@ use std::{ }; use sync_wrapper::SyncWrapper; use tower::Service; +use tower_layer::Layer; /// A [`Router`] converted into a [`Service`]. #[derive(Debug)] @@ -76,6 +78,58 @@ where route.call(req) } + /// Apply a [`tower::Layer`] to all routes in the router. + /// + /// See [`Router::layer`] for more details. + pub fn layer<L, NewReqBody>(self, layer: L) -> RouterService<NewReqBody> + where + L: Layer<Route<B>> + Clone + Send + 'static, + L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, + <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, + <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, + <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, + NewReqBody: 'static, + { + let routes = self + .routes + .into_iter() + .map(|(id, route)| (id, route.layer(layer.clone()))) + .collect(); + + let fallback = self.fallback.layer(layer); + + RouterService { + routes, + node: self.node, + fallback, + } + } + + /// Apply a [`tower::Layer`] to the router that will only run if the request matches + /// a route. + /// + /// See [`Router::route_layer`] for more details. + pub fn route_layer<L>(self, layer: L) -> Self + where + L: Layer<Route<B>> + Clone + Send + 'static, + L::Service: Service<Request<B>> + Clone + Send + 'static, + <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, + <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, + <L::Service as Service<Request<B>>>::Future: Send + 'static, + { + let routes = self + .routes + .into_iter() + .map(|(id, route)| (id, route.layer(layer.clone()))) + .collect(); + + Self { + routes, + node: self.node, + fallback: self.fallback, + } + } + /// Convert the `RouterService` into a [`MakeService`]. /// /// See [`Router::into_make_service`] for more details.