Fix trailing redirection with query parameters (#936)

* Fix trailing redirection with query parameters

When the request URI matches a route that need a trailing slash, or has an extra trailing slash, the redirect URI is not generated correctly.

This change adds or removes a trailing slash to the path part of the URI, instead of the full URI, preserving query parameters during redirection.

Signed-off-by: David Calavera <david.calavera@gmail.com>

* Make trailing slash logic safer

Extract parts from Uri and recreate it, so it doesn't bump
into corner cases with string manipulation.

Signed-off-by: David Calavera <david.calavera@gmail.com>

* Remove extra assignment.

Signed-off-by: David Calavera <david.calavera@gmail.com>

* Update axum/src/routing/mod.rs

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>

* changelog

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
David Calavera 2022-04-17 14:12:52 -07:00 committed by GitHub
parent 0313c08dc9
commit 83e1a15040
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 8 deletions

View file

@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `axum::extract::multipart::Field::chunk` method for streaming a single chunk from - **added:** Add `axum::extract::multipart::Field::chunk` method for streaming a single chunk from
the field ([#901]) the field ([#901])
- **fixed:** Fix trailing slash redirection with query parameters ([#936])
[#901]: https://github.com/tokio-rs/axum/pull/901 [#901]: https://github.com/tokio-rs/axum/pull/901
[#936]: https://github.com/tokio-rs/axum/pull/936
# 0.5.1 (03. April, 2022) # 0.5.1 (03. April, 2022)

View file

@ -9,7 +9,7 @@ use crate::{
util::try_downcast, util::try_downcast,
BoxError, BoxError,
}; };
use http::Request; use http::{Request, Uri};
use matchit::MatchError; use matchit::MatchError;
use std::{ use std::{
borrow::Cow, borrow::Cow,
@ -497,13 +497,20 @@ where
match self.node.at(&path) { match self.node.at(&path) {
Ok(match_) => self.call_route(match_, req), Ok(match_) => self.call_route(match_, req),
Err(MatchError::MissingTrailingSlash) => RouteFuture::from_response( Err(MatchError::MissingTrailingSlash) => {
Redirect::permanent(&format!("{}/", req.uri())).into_response(), let new_uri = replace_trailing_slash(req.uri(), &format!("{}/", &path));
),
Err(MatchError::ExtraTrailingSlash) => RouteFuture::from_response( RouteFuture::from_response(
Redirect::permanent(req.uri().to_string().strip_suffix('/').unwrap()) Redirect::permanent(&new_uri.to_string()).into_response(),
.into_response(), )
), }
Err(MatchError::ExtraTrailingSlash) => {
let new_uri = replace_trailing_slash(req.uri(), &path.strip_suffix('/').unwrap());
RouteFuture::from_response(
Redirect::permanent(&new_uri.to_string()).into_response(),
)
}
Err(MatchError::NotFound) => match &self.fallback { Err(MatchError::NotFound) => match &self.fallback {
Fallback::Default(inner) => inner.clone().call(req), Fallback::Default(inner) => inner.clone().call(req),
Fallback::Custom(inner) => inner.clone().call(req), Fallback::Custom(inner) => inner.clone().call(req),
@ -512,6 +519,19 @@ where
} }
} }
fn replace_trailing_slash(uri: &Uri, new_path: &str) -> Uri {
let mut new_path_and_query = new_path.to_string();
if let Some(query) = uri.query() {
new_path_and_query.push('?');
new_path_and_query.push_str(query);
}
let mut parts = uri.clone().into_parts();
parts.path_and_query = Some(new_path_and_query.parse().unwrap());
Uri::from_parts(parts).unwrap()
}
/// Wrapper around `matchit::Router` that supports merging two `Router`s. /// Wrapper around `matchit::Router` that supports merging two `Router`s.
#[derive(Clone, Default)] #[derive(Clone, Default)]
struct Node { struct Node {

View file

@ -324,6 +324,10 @@ async fn with_trailing_slash() {
let res = client.get("/foo/").send().await; let res = client.get("/foo/").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers().get("location").unwrap(), "/foo"); assert_eq!(res.headers().get("location").unwrap(), "/foo");
let res = client.get("/foo/?bar=baz").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers().get("location").unwrap(), "/foo?bar=baz");
} }
#[tokio::test] #[tokio::test]
@ -335,6 +339,10 @@ async fn without_trailing_slash() {
let res = client.get("/foo").send().await; let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers().get("location").unwrap(), "/foo/"); assert_eq!(res.headers().get("location").unwrap(), "/foo/");
let res = client.get("/foo?bar=baz").send().await;
assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
assert_eq!(res.headers().get("location").unwrap(), "/foo/?bar=baz");
} }
#[tokio::test] #[tokio::test]