From 7386e5d185dd8238a0b79eca8e11dee8d697454c Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Fri, 2 Dec 2022 11:42:49 +0100
Subject: [PATCH] Bug fixes for `RouterExt:{route_with_tsr,
 route_service_with_tsr}` (#1608)

* Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}`

* changelog link
---
 axum-extra/CHANGELOG.md       |   7 +-
 axum-extra/src/routing/mod.rs | 129 ++++++++++++++++++++++++++++------
 2 files changed, 113 insertions(+), 23 deletions(-)

diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md
index d72d54f7..e11e2606 100644
--- a/axum-extra/CHANGELOG.md
+++ b/axum-extra/CHANGELOG.md
@@ -7,7 +7,12 @@ and this project adheres to [Semantic Versioning].
 
 # Unreleased
 
-- None.
+- **fixed:** Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` ([#1608]):
+  - Redirects to the correct URI if the route contains path parameters
+  - Keeps query parameters when redirecting
+  - Better improved error message if adding route for `/`
+
+[#1608]: https://github.com/tokio-rs/axum/pull/1608
 
 # 0.4.1 (29. November, 2022)
 
diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs
index 23a9dde2..6d52c284 100644
--- a/axum-extra/src/routing/mod.rs
+++ b/axum-extra/src/routing/mod.rs
@@ -1,13 +1,13 @@
 //! Additional types for defining routes.
 
 use axum::{
-    handler::HandlerWithoutStateExt,
     http::Request,
-    response::{IntoResponse, Redirect},
+    response::{IntoResponse, Redirect, Response},
     routing::{any, MethodRouter},
     Router,
 };
-use std::{convert::Infallible, future::ready, sync::Arc};
+use http::{uri::PathAndQuery, StatusCode, Uri};
+use std::{borrow::Cow, convert::Infallible};
 use tower_service::Service;
 
 mod resource;
@@ -265,18 +265,9 @@ where
     where
         Self: Sized,
     {
+        validate_tsr_path(path);
         self = self.route(path, method_router);
-
-        let redirect_service = {
-            let path: Arc<str> = path.into();
-            (move || ready(Redirect::permanent(&path))).into_service()
-        };
-
-        if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
-            self.route_service(path_without_trailing_slash, redirect_service)
-        } else {
-            self.route_service(&format!("{}/", path), redirect_service)
-        }
+        add_tsr_redirect_route(self, path)
     }
 
     #[track_caller]
@@ -287,19 +278,65 @@ where
         T::Future: Send + 'static,
         Self: Sized,
     {
+        validate_tsr_path(path);
         self = self.route_service(path, service);
+        add_tsr_redirect_route(self, path)
+    }
+}
 
-        let redirect = Redirect::permanent(path);
+#[track_caller]
+fn validate_tsr_path(path: &str) {
+    if path == "/" {
+        panic!("Cannot add a trailing slash redirect route for `/`")
+    }
+}
 
-        if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
-            self.route(
-                path_without_trailing_slash,
-                any(move || ready(redirect.clone())),
-            )
+fn add_tsr_redirect_route<S, B>(router: Router<S, B>, path: &str) -> Router<S, B>
+where
+    B: axum::body::HttpBody + Send + 'static,
+    S: Clone + Send + Sync + 'static,
+{
+    async fn redirect_handler(uri: Uri) -> Response {
+        let new_uri = map_path(uri, |path| {
+            path.strip_suffix('/')
+                .map(Cow::Borrowed)
+                .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
+        });
+
+        if let Some(new_uri) = new_uri {
+            Redirect::permanent(&new_uri.to_string()).into_response()
         } else {
-            self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
+            StatusCode::BAD_REQUEST.into_response()
         }
     }
+
+    if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
+        router.route(path_without_trailing_slash, any(redirect_handler))
+    } else {
+        router.route(&format!("{}/", path), any(redirect_handler))
+    }
+}
+
+/// Map the path of a `Uri`.
+///
+/// Returns `None` if the `Uri` cannot be put back together with the new path.
+fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
+where
+    F: FnOnce(&str) -> Cow<'_, str>,
+{
+    let mut parts = original_uri.into_parts();
+    let path_and_query = parts.path_and_query.as_ref()?;
+
+    let new_path = f(path_and_query.path());
+
+    let new_path_and_query = if let Some(query) = &path_and_query.query() {
+        format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
+    } else {
+        new_path.parse::<PathAndQuery>().ok()?
+    };
+    parts.path_and_query = Some(new_path_and_query);
+
+    Uri::from_parts(parts).ok()
 }
 
 mod sealed {
@@ -311,7 +348,7 @@ mod sealed {
 mod tests {
     use super::*;
     use crate::test_helpers::*;
-    use axum::{http::StatusCode, routing::get};
+    use axum::{extract::Path, http::StatusCode, routing::get};
 
     #[tokio::test]
     async fn test_tsr() {
@@ -335,4 +372,52 @@ mod tests {
         assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
         assert_eq!(res.headers()["location"], "/bar/");
     }
+
+    #[tokio::test]
+    async fn tsr_with_params() {
+        let app = Router::new()
+            .route_with_tsr(
+                "/a/:a",
+                get(|Path(param): Path<String>| async move { param }),
+            )
+            .route_with_tsr(
+                "/b/:b/",
+                get(|Path(param): Path<String>| async move { param }),
+            );
+
+        let client = TestClient::new(app);
+
+        let res = client.get("/a/foo").send().await;
+        assert_eq!(res.status(), StatusCode::OK);
+        assert_eq!(res.text().await, "foo");
+
+        let res = client.get("/a/foo/").send().await;
+        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
+        assert_eq!(res.headers()["location"], "/a/foo");
+
+        let res = client.get("/b/foo/").send().await;
+        assert_eq!(res.status(), StatusCode::OK);
+        assert_eq!(res.text().await, "foo");
+
+        let res = client.get("/b/foo").send().await;
+        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
+        assert_eq!(res.headers()["location"], "/b/foo/");
+    }
+
+    #[tokio::test]
+    async fn tsr_maintains_query_params() {
+        let app = Router::new().route_with_tsr("/foo", get(|| async {}));
+
+        let client = TestClient::new(app);
+
+        let res = client.get("/foo/?a=a").send().await;
+        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
+        assert_eq!(res.headers()["location"], "/foo?a=a");
+    }
+
+    #[test]
+    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
+    fn tsr_at_root() {
+        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
+    }
 }