diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index d72d54f7..e11e2606 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,7 +7,12 @@ and this project adheres to [Semantic Versioning]. # Unreleased -- None. +- **fixed:** Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` ([#1608]): + - Redirects to the correct URI if the route contains path parameters + - Keeps query parameters when redirecting + - Better improved error message if adding route for `/` + +[#1608]: https://github.com/tokio-rs/axum/pull/1608 # 0.4.1 (29. November, 2022) diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 23a9dde2..6d52c284 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,13 +1,13 @@ //! Additional types for defining routes. use axum::{ - handler::HandlerWithoutStateExt, http::Request, - response::{IntoResponse, Redirect}, + response::{IntoResponse, Redirect, Response}, routing::{any, MethodRouter}, Router, }; -use std::{convert::Infallible, future::ready, sync::Arc}; +use http::{uri::PathAndQuery, StatusCode, Uri}; +use std::{borrow::Cow, convert::Infallible}; use tower_service::Service; mod resource; @@ -265,18 +265,9 @@ where where Self: Sized, { + validate_tsr_path(path); self = self.route(path, method_router); - - let redirect_service = { - let path: Arc<str> = path.into(); - (move || ready(Redirect::permanent(&path))).into_service() - }; - - if let Some(path_without_trailing_slash) = path.strip_suffix('/') { - self.route_service(path_without_trailing_slash, redirect_service) - } else { - self.route_service(&format!("{}/", path), redirect_service) - } + add_tsr_redirect_route(self, path) } #[track_caller] @@ -287,19 +278,65 @@ where T::Future: Send + 'static, Self: Sized, { + validate_tsr_path(path); self = self.route_service(path, service); + add_tsr_redirect_route(self, path) + } +} - let redirect = Redirect::permanent(path); +#[track_caller] +fn validate_tsr_path(path: &str) { + if path == "/" { + panic!("Cannot add a trailing slash redirect route for `/`") + } +} - if let Some(path_without_trailing_slash) = path.strip_suffix('/') { - self.route( - path_without_trailing_slash, - any(move || ready(redirect.clone())), - ) +fn add_tsr_redirect_route<S, B>(router: Router<S, B>, path: &str) -> Router<S, B> +where + B: axum::body::HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + async fn redirect_handler(uri: Uri) -> Response { + let new_uri = map_path(uri, |path| { + path.strip_suffix('/') + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(format!("{path}/"))) + }); + + if let Some(new_uri) = new_uri { + Redirect::permanent(&new_uri.to_string()).into_response() } else { - self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) + StatusCode::BAD_REQUEST.into_response() } } + + if let Some(path_without_trailing_slash) = path.strip_suffix('/') { + router.route(path_without_trailing_slash, any(redirect_handler)) + } else { + router.route(&format!("{}/", path), any(redirect_handler)) + } +} + +/// Map the path of a `Uri`. +/// +/// Returns `None` if the `Uri` cannot be put back together with the new path. +fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri> +where + F: FnOnce(&str) -> Cow<'_, str>, +{ + let mut parts = original_uri.into_parts(); + let path_and_query = parts.path_and_query.as_ref()?; + + let new_path = f(path_and_query.path()); + + let new_path_and_query = if let Some(query) = &path_and_query.query() { + format!("{new_path}?{query}").parse::<PathAndQuery>().ok()? + } else { + new_path.parse::<PathAndQuery>().ok()? + }; + parts.path_and_query = Some(new_path_and_query); + + Uri::from_parts(parts).ok() } mod sealed { @@ -311,7 +348,7 @@ mod sealed { mod tests { use super::*; use crate::test_helpers::*; - use axum::{http::StatusCode, routing::get}; + use axum::{extract::Path, http::StatusCode, routing::get}; #[tokio::test] async fn test_tsr() { @@ -335,4 +372,52 @@ mod tests { assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/bar/"); } + + #[tokio::test] + async fn tsr_with_params() { + let app = Router::new() + .route_with_tsr( + "/a/:a", + get(|Path(param): Path<String>| async move { param }), + ) + .route_with_tsr( + "/b/:b/", + get(|Path(param): Path<String>| async move { param }), + ); + + let client = TestClient::new(app); + + let res = client.get("/a/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "foo"); + + let res = client.get("/a/foo/").send().await; + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + assert_eq!(res.headers()["location"], "/a/foo"); + + let res = client.get("/b/foo/").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "foo"); + + let res = client.get("/b/foo").send().await; + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + assert_eq!(res.headers()["location"], "/b/foo/"); + } + + #[tokio::test] + async fn tsr_maintains_query_params() { + let app = Router::new().route_with_tsr("/foo", get(|| async {})); + + let client = TestClient::new(app); + + let res = client.get("/foo/?a=a").send().await; + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + assert_eq!(res.headers()["location"], "/foo?a=a"); + } + + #[test] + #[should_panic = "Cannot add a trailing slash redirect route for `/`"] + fn tsr_at_root() { + let _: Router = Router::new().route_with_tsr("/", get(|| async move {})); + } }