mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-16 06:30:39 +01:00
Bug fixes for RouterExt:{route_with_tsr, route_service_with_tsr}
(#1608)
* Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` * changelog link
This commit is contained in:
parent
56d0dd9ec2
commit
7386e5d185
2 changed files with 113 additions and 23 deletions
|
@ -7,7 +7,12 @@ and this project adheres to [Semantic Versioning].
|
|||
|
||||
# Unreleased
|
||||
|
||||
- None.
|
||||
- **fixed:** Bug fixes for `RouterExt:{route_with_tsr, route_service_with_tsr}` ([#1608]):
|
||||
- Redirects to the correct URI if the route contains path parameters
|
||||
- Keeps query parameters when redirecting
|
||||
- Better improved error message if adding route for `/`
|
||||
|
||||
[#1608]: https://github.com/tokio-rs/axum/pull/1608
|
||||
|
||||
# 0.4.1 (29. November, 2022)
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
//! Additional types for defining routes.
|
||||
|
||||
use axum::{
|
||||
handler::HandlerWithoutStateExt,
|
||||
http::Request,
|
||||
response::{IntoResponse, Redirect},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::{any, MethodRouter},
|
||||
Router,
|
||||
};
|
||||
use std::{convert::Infallible, future::ready, sync::Arc};
|
||||
use http::{uri::PathAndQuery, StatusCode, Uri};
|
||||
use std::{borrow::Cow, convert::Infallible};
|
||||
use tower_service::Service;
|
||||
|
||||
mod resource;
|
||||
|
@ -265,18 +265,9 @@ where
|
|||
where
|
||||
Self: Sized,
|
||||
{
|
||||
validate_tsr_path(path);
|
||||
self = self.route(path, method_router);
|
||||
|
||||
let redirect_service = {
|
||||
let path: Arc<str> = path.into();
|
||||
(move || ready(Redirect::permanent(&path))).into_service()
|
||||
};
|
||||
|
||||
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
|
||||
self.route_service(path_without_trailing_slash, redirect_service)
|
||||
} else {
|
||||
self.route_service(&format!("{}/", path), redirect_service)
|
||||
}
|
||||
add_tsr_redirect_route(self, path)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
|
@ -287,19 +278,65 @@ where
|
|||
T::Future: Send + 'static,
|
||||
Self: Sized,
|
||||
{
|
||||
validate_tsr_path(path);
|
||||
self = self.route_service(path, service);
|
||||
add_tsr_redirect_route(self, path)
|
||||
}
|
||||
}
|
||||
|
||||
let redirect = Redirect::permanent(path);
|
||||
#[track_caller]
|
||||
fn validate_tsr_path(path: &str) {
|
||||
if path == "/" {
|
||||
panic!("Cannot add a trailing slash redirect route for `/`")
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
|
||||
self.route(
|
||||
path_without_trailing_slash,
|
||||
any(move || ready(redirect.clone())),
|
||||
)
|
||||
fn add_tsr_redirect_route<S, B>(router: Router<S, B>, path: &str) -> Router<S, B>
|
||||
where
|
||||
B: axum::body::HttpBody + Send + 'static,
|
||||
S: Clone + Send + Sync + 'static,
|
||||
{
|
||||
async fn redirect_handler(uri: Uri) -> Response {
|
||||
let new_uri = map_path(uri, |path| {
|
||||
path.strip_suffix('/')
|
||||
.map(Cow::Borrowed)
|
||||
.unwrap_or_else(|| Cow::Owned(format!("{path}/")))
|
||||
});
|
||||
|
||||
if let Some(new_uri) = new_uri {
|
||||
Redirect::permanent(&new_uri.to_string()).into_response()
|
||||
} else {
|
||||
self.route(&format!("{}/", path), any(move || ready(redirect.clone())))
|
||||
StatusCode::BAD_REQUEST.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
|
||||
router.route(path_without_trailing_slash, any(redirect_handler))
|
||||
} else {
|
||||
router.route(&format!("{}/", path), any(redirect_handler))
|
||||
}
|
||||
}
|
||||
|
||||
/// Map the path of a `Uri`.
|
||||
///
|
||||
/// Returns `None` if the `Uri` cannot be put back together with the new path.
|
||||
fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
|
||||
where
|
||||
F: FnOnce(&str) -> Cow<'_, str>,
|
||||
{
|
||||
let mut parts = original_uri.into_parts();
|
||||
let path_and_query = parts.path_and_query.as_ref()?;
|
||||
|
||||
let new_path = f(path_and_query.path());
|
||||
|
||||
let new_path_and_query = if let Some(query) = &path_and_query.query() {
|
||||
format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
|
||||
} else {
|
||||
new_path.parse::<PathAndQuery>().ok()?
|
||||
};
|
||||
parts.path_and_query = Some(new_path_and_query);
|
||||
|
||||
Uri::from_parts(parts).ok()
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
|
@ -311,7 +348,7 @@ mod sealed {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::test_helpers::*;
|
||||
use axum::{http::StatusCode, routing::get};
|
||||
use axum::{extract::Path, http::StatusCode, routing::get};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tsr() {
|
||||
|
@ -335,4 +372,52 @@ mod tests {
|
|||
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
|
||||
assert_eq!(res.headers()["location"], "/bar/");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tsr_with_params() {
|
||||
let app = Router::new()
|
||||
.route_with_tsr(
|
||||
"/a/:a",
|
||||
get(|Path(param): Path<String>| async move { param }),
|
||||
)
|
||||
.route_with_tsr(
|
||||
"/b/:b/",
|
||||
get(|Path(param): Path<String>| async move { param }),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/a/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "foo");
|
||||
|
||||
let res = client.get("/a/foo/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
|
||||
assert_eq!(res.headers()["location"], "/a/foo");
|
||||
|
||||
let res = client.get("/b/foo/").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "foo");
|
||||
|
||||
let res = client.get("/b/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
|
||||
assert_eq!(res.headers()["location"], "/b/foo/");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tsr_maintains_query_params() {
|
||||
let app = Router::new().route_with_tsr("/foo", get(|| async {}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/?a=a").send().await;
|
||||
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
|
||||
assert_eq!(res.headers()["location"], "/foo?a=a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Cannot add a trailing slash redirect route for `/`"]
|
||||
fn tsr_at_root() {
|
||||
let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue