From b816ac7cdf14d8dade5517ee6b49989e1db62074 Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Sat, 19 Nov 2022 22:51:55 +0100
Subject: [PATCH] Add `RouterService::{layer, route_layer}` (#1550)

* Add `RouterService::{layer, route_layer}`

Figure we might as well have these.

* changelog
---
 axum/CHANGELOG.md           |  4 ++-
 axum/src/routing/mod.rs     | 21 ++++++++++++++-
 axum/src/routing/service.rs | 54 +++++++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md
index 6f43c760..bbf51ef8 100644
--- a/axum/CHANGELOG.md
+++ b/axum/CHANGELOG.md
@@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 # Unreleased
 
-- None.
+- **added:** Add `RouterService::{layer, route_layer}` ([#1550])
+
+[#1550]: https://github.com/tokio-rs/axum/pull/1550
 
 # 0.6.0-rc.5 (18. November, 2022)
 
diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs
index 5eaf1c0d..2544ac41 100644
--- a/axum/src/routing/mod.rs
+++ b/axum/src/routing/mod.rs
@@ -311,13 +311,14 @@ where
     }
 
     #[doc = include_str!("../docs/routing/layer.md")]
-    pub fn layer<L, NewReqBody: 'static>(self, layer: L) -> Router<S, NewReqBody>
+    pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
     where
         L: Layer<Route<B>> + Clone + Send + 'static,
         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
         <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
+        NewReqBody: 'static,
     {
         let routes = self
             .routes
@@ -566,6 +567,24 @@ pub(crate) enum FallbackRoute<B, E = Infallible> {
     Service(Route<B, E>),
 }
 
+impl<B, E> FallbackRoute<B, E> {
+    fn layer<L, NewReqBody, NewError>(self, layer: L) -> FallbackRoute<NewReqBody, NewError>
+    where
+        L: Layer<Route<B, E>> + Clone + Send + 'static,
+        L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
+        NewReqBody: 'static,
+        NewError: 'static,
+    {
+        match self {
+            FallbackRoute::Default(route) => FallbackRoute::Default(route.layer(layer)),
+            FallbackRoute::Service(route) => FallbackRoute::Service(route.layer(layer)),
+        }
+    }
+}
+
 impl<B, E> fmt::Debug for FallbackRoute<B, E> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs
index 141bd442..cbd655d8 100644
--- a/axum/src/routing/service.rs
+++ b/axum/src/routing/service.rs
@@ -5,6 +5,7 @@ use crate::{
     body::{Body, HttpBody},
     response::Response,
 };
+use axum_core::response::IntoResponse;
 use http::Request;
 use matchit::MatchError;
 use std::{
@@ -15,6 +16,7 @@ use std::{
 };
 use sync_wrapper::SyncWrapper;
 use tower::Service;
+use tower_layer::Layer;
 
 /// A [`Router`] converted into a [`Service`].
 #[derive(Debug)]
@@ -76,6 +78,58 @@ where
         route.call(req)
     }
 
+    /// Apply a [`tower::Layer`] to all routes in the router.
+    ///
+    /// See [`Router::layer`] for more details.
+    pub fn layer<L, NewReqBody>(self, layer: L) -> RouterService<NewReqBody>
+    where
+        L: Layer<Route<B>> + Clone + Send + 'static,
+        L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
+        <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
+        NewReqBody: 'static,
+    {
+        let routes = self
+            .routes
+            .into_iter()
+            .map(|(id, route)| (id, route.layer(layer.clone())))
+            .collect();
+
+        let fallback = self.fallback.layer(layer);
+
+        RouterService {
+            routes,
+            node: self.node,
+            fallback,
+        }
+    }
+
+    /// Apply a [`tower::Layer`] to the router that will only run if the request matches
+    /// a route.
+    ///
+    /// See [`Router::route_layer`] for more details.
+    pub fn route_layer<L>(self, layer: L) -> Self
+    where
+        L: Layer<Route<B>> + Clone + Send + 'static,
+        L::Service: Service<Request<B>> + Clone + Send + 'static,
+        <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
+        <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
+        <L::Service as Service<Request<B>>>::Future: Send + 'static,
+    {
+        let routes = self
+            .routes
+            .into_iter()
+            .map(|(id, route)| (id, route.layer(layer.clone())))
+            .collect();
+
+        Self {
+            routes,
+            node: self.node,
+            fallback: self.fallback,
+        }
+    }
+
     /// Convert the `RouterService` into a [`MakeService`].
     ///
     /// See [`Router::into_make_service`] for more details.