Correctly handle trailing slashes in routes (#410)

This commit is contained in:
David Pedersen 2021-10-25 22:42:49 +02:00 committed by GitHub
parent baf7cabfe1
commit 59819e42bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 13 deletions

View file

@ -126,6 +126,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`)
now no longer catch requests that would otherwise be 404s. They also work
correctly with `Router::merge` (previously called `or`) ([#408])
- **fixed:** Correctly handle trailing slashes in routes:
- If a route with a trailing slash exists and a request without a trailing
slash is received, axum will send a 301 redirection to the route with the
trailing slash.
- Or vice versa if a route without a trailing slash exists and a request
with a trailing slash is received.
- This can be overridden by explicitly defining two routes: One with and one
without trailing a slash.
[#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286

View file

@ -21,6 +21,20 @@ opaque_future! {
>;
}
impl<B> RouterFuture<B> {
pub(super) fn from_oneshot(future: Oneshot<super::Route<B>, Request<B>>) -> Self {
Self {
future: futures_util::future::Either::Left(future),
}
}
pub(super) fn from_response(response: Response<BoxBody>) -> Self {
RouterFuture {
future: futures_util::future::Either::Right(std::future::ready(Ok(response))),
}
}
}
opaque_future! {
/// Response future for [`Route`](super::Route).
pub type RouteFuture =

View file

@ -759,9 +759,7 @@ where
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
RouterFuture {
future: futures_util::future::Either::Left(route.oneshot(req)),
}
RouterFuture::from_oneshot(route.oneshot(req))
}
}
@ -787,16 +785,30 @@ where
let path = req.uri().path().to_string();
if let Ok(match_) = self.node.at(&path) {
self.call_route(match_, req)
} else if let Some(fallback) = &self.fallback {
RouterFuture {
future: futures_util::future::Either::Left(fallback.clone().oneshot(req)),
}
} else {
let res = EmptyRouter::<Infallible>::not_found().call_sync(req);
RouterFuture {
future: futures_util::future::Either::Right(std::future::ready(Ok(res))),
match self.node.at(&path) {
Ok(match_) => self.call_route(match_, req),
Err(err) => {
if err.tsr()
// workaround for https://github.com/ibraheemdev/matchit/issues/7
&& path != "/"
{
let redirect_to = if let Some(without_tsr) = path.strip_suffix('/') {
with_path(req.uri(), without_tsr)
} else {
with_path(req.uri(), &format!("{}/", path))
};
let res = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY)
.header(http::header::LOCATION, redirect_to.to_string())
.body(crate::body::empty())
.unwrap();
RouterFuture::from_response(res)
} else if let Some(fallback) = &self.fallback {
RouterFuture::from_oneshot(fallback.clone().oneshot(req))
} else {
let res = EmptyRouter::<Infallible>::not_found().call_sync(req);
RouterFuture::from_response(res)
}
}
}
}

View file

@ -579,6 +579,45 @@ async fn middleware_that_return_early() {
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
}
#[tokio::test]
async fn with_trailing_slash() {
let app = Router::new().route("/foo", get(|| async {}));
let client = TestClient::new(app);
// `TestClient` automatically follows redirects
let res = client.get("/foo/").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn without_trailing_slash() {
let app = Router::new().route("/foo/", get(|| async {}));
let client = TestClient::new(app);
// `TestClient` automatically follows redirects
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn with_and_without_trailing_slash() {
let app = Router::new()
.route("/foo", get(|| async { "without tsr" }))
.route("/foo/", get(|| async { "with tsr" }));
let client = TestClient::new(app);
let res = client.get("/foo/").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "with tsr");
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "without tsr");
}
pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {}