diff --git a/CHANGELOG.md b/CHANGELOG.md index beeb15da..9d393ded 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -126,6 +126,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`) now no longer catch requests that would otherwise be 404s. They also work correctly with `Router::merge` (previously called `or`) ([#408]) +- **fixed:** Correctly handle trailing slashes in routes: + - If a route with a trailing slash exists and a request without a trailing + slash is received, axum will send a 301 redirection to the route with the + trailing slash. + - Or vice versa if a route without a trailing slash exists and a request + with a trailing slash is received. + - This can be overridden by explicitly defining two routes: One with and one + without trailing a slash. [#339]: https://github.com/tokio-rs/axum/pull/339 [#286]: https://github.com/tokio-rs/axum/pull/286 diff --git a/src/routing/future.rs b/src/routing/future.rs index 6b01c3aa..07ec279c 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -21,6 +21,20 @@ opaque_future! { >; } +impl RouterFuture { + pub(super) fn from_oneshot(future: Oneshot, Request>) -> Self { + Self { + future: futures_util::future::Either::Left(future), + } + } + + pub(super) fn from_response(response: Response) -> Self { + RouterFuture { + future: futures_util::future::Either::Right(std::future::ready(Ok(response))), + } + } +} + opaque_future! { /// Response future for [`Route`](super::Route). pub type RouteFuture = diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 0d54f97e..22eeba21 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -759,9 +759,7 @@ where .expect("no route for id. This is a bug in axum. Please file an issue") .clone(); - RouterFuture { - future: futures_util::future::Either::Left(route.oneshot(req)), - } + RouterFuture::from_oneshot(route.oneshot(req)) } } @@ -787,16 +785,30 @@ where let path = req.uri().path().to_string(); - if let Ok(match_) = self.node.at(&path) { - self.call_route(match_, req) - } else if let Some(fallback) = &self.fallback { - RouterFuture { - future: futures_util::future::Either::Left(fallback.clone().oneshot(req)), - } - } else { - let res = EmptyRouter::::not_found().call_sync(req); - RouterFuture { - future: futures_util::future::Either::Right(std::future::ready(Ok(res))), + match self.node.at(&path) { + Ok(match_) => self.call_route(match_, req), + Err(err) => { + if err.tsr() + // workaround for https://github.com/ibraheemdev/matchit/issues/7 + && path != "/" + { + let redirect_to = if let Some(without_tsr) = path.strip_suffix('/') { + with_path(req.uri(), without_tsr) + } else { + with_path(req.uri(), &format!("{}/", path)) + }; + let res = Response::builder() + .status(StatusCode::MOVED_PERMANENTLY) + .header(http::header::LOCATION, redirect_to.to_string()) + .body(crate::body::empty()) + .unwrap(); + RouterFuture::from_response(res) + } else if let Some(fallback) = &self.fallback { + RouterFuture::from_oneshot(fallback.clone().oneshot(req)) + } else { + let res = EmptyRouter::::not_found().call_sync(req); + RouterFuture::from_response(res) + } } } } diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 73b36692..4df09fe0 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -579,6 +579,45 @@ async fn middleware_that_return_early() { assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } +#[tokio::test] +async fn with_trailing_slash() { + let app = Router::new().route("/foo", get(|| async {})); + + let client = TestClient::new(app); + + // `TestClient` automatically follows redirects + let res = client.get("/foo/").send().await; + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn without_trailing_slash() { + let app = Router::new().route("/foo/", get(|| async {})); + + let client = TestClient::new(app); + + // `TestClient` automatically follows redirects + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn with_and_without_trailing_slash() { + let app = Router::new() + .route("/foo", get(|| async { "without tsr" })) + .route("/foo/", get(|| async { "with tsr" })); + + let client = TestClient::new(app); + + let res = client.get("/foo/").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "with tsr"); + + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "without tsr"); +} + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {}