Bug fixes for RouterExt:{route_with_tsr, route_service_with_tsr} (#1608)

* Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}`

* changelog link
This commit is contained in:
David Pedersen 2022-12-02 11:42:49 +01:00 committed by GitHub
parent 56d0dd9ec2
commit 7386e5d185
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 23 deletions

View file

@ -7,7 +7,12 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- None.
- **fixed:** Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` ([#1608]):
- Redirects to the correct URI if the route contains path parameters
- Keeps query parameters when redirecting
- Better improved error message if adding route for `/`
[#1608]: https://github.com/tokio-rs/axum/pull/1608
# 0.4.1 (29. November, 2022)

View file

@ -1,13 +1,13 @@
//! Additional types for defining routes.
use axum::{
handler::HandlerWithoutStateExt,
http::Request,
response::{IntoResponse, Redirect},
response::{IntoResponse, Redirect, Response},
routing::{any, MethodRouter},
Router,
};
use std::{convert::Infallible, future::ready, sync::Arc};
use http::{uri::PathAndQuery, StatusCode, Uri};
use std::{borrow::Cow, convert::Infallible};
use tower_service::Service;
mod resource;
@ -265,18 +265,9 @@ where
where
Self: Sized,
{
validate_tsr_path(path);
self = self.route(path, method_router);
let redirect_service = {
let path: Arc<str> = path.into();
(move || ready(Redirect::permanent(&path))).into_service()
};
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route_service(path_without_trailing_slash, redirect_service)
} else {
self.route_service(&format!("{}/", path), redirect_service)
}
add_tsr_redirect_route(self, path)
}
#[track_caller]
@ -287,19 +278,65 @@ where
T::Future: Send + 'static,
Self: Sized,
{
validate_tsr_path(path);
self = self.route_service(path, service);
add_tsr_redirect_route(self, path)
}
}
let redirect = Redirect::permanent(path);
#[track_caller]
fn validate_tsr_path(path: &str) {
if path == "/" {
panic!("Cannot add a trailing slash redirect route for `/`")
}
}
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route(
path_without_trailing_slash,
any(move || ready(redirect.clone())),
)
fn add_tsr_redirect_route<S, B>(router: Router<S, B>, path: &str) -> Router<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
async fn redirect_handler(uri: Uri) -> Response {
let new_uri = map_path(uri, |path| {
path.strip_suffix('/')
.map(Cow::Borrowed)
.unwrap_or_else(|| Cow::Owned(format!("{path}/")))
});
if let Some(new_uri) = new_uri {
Redirect::permanent(&new_uri.to_string()).into_response()
} else {
self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
StatusCode::BAD_REQUEST.into_response()
}
}
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
router.route(path_without_trailing_slash, any(redirect_handler))
} else {
router.route(&format!("{}/", path), any(redirect_handler))
}
}
/// Map the path of a `Uri`.
///
/// Returns `None` if the `Uri` cannot be put back together with the new path.
fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
where
F: FnOnce(&str) -> Cow<'_, str>,
{
let mut parts = original_uri.into_parts();
let path_and_query = parts.path_and_query.as_ref()?;
let new_path = f(path_and_query.path());
let new_path_and_query = if let Some(query) = &path_and_query.query() {
format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
} else {
new_path.parse::<PathAndQuery>().ok()?
};
parts.path_and_query = Some(new_path_and_query);
Uri::from_parts(parts).ok()
}
mod sealed {
@ -311,7 +348,7 @@ mod sealed {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{http::StatusCode, routing::get};
use axum::{extract::Path, http::StatusCode, routing::get};
#[tokio::test]
async fn test_tsr() {
@ -335,4 +372,52 @@ mod tests {
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/bar/");
}
#[tokio::test]
async fn tsr_with_params() {
let app = Router::new()
.route_with_tsr(
"/a/:a",
get(|Path(param): Path<String>| async move { param }),
)
.route_with_tsr(
"/b/:b/",
get(|Path(param): Path<String>| async move { param }),
);
let client = TestClient::new(app);
let res = client.get("/a/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "foo");
let res = client.get("/a/foo/").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/a/foo");
let res = client.get("/b/foo/").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "foo");
let res = client.get("/b/foo").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/b/foo/");
}
#[tokio::test]
async fn tsr_maintains_query_params() {
let app = Router::new().route_with_tsr("/foo", get(|| async {}));
let client = TestClient::new(app);
let res = client.get("/foo/?a=a").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers()["location"], "/foo?a=a");
}
#[test]
#[should_panic = "Cannot add a trailing slash redirect route for `/`"]
fn tsr_at_root() {
let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
}
}