diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 10e4ddcc..da9cdeca 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -33,7 +33,7 @@ http = "0.2.5" http-body = "0.4.4" hyper = { version = "0.14.14", features = ["server", "tcp", "stream"] } itoa = "1.0.1" -matchit = "0.4.6" +matchit = "0.5.0" memchr = "2.4.1" mime = "0.3.16" percent-encoding = "2.1" diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index d80a3fd1..7e957228 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -10,6 +10,7 @@ use crate::{ BoxError, }; use http::{Request, Uri}; +use matchit::MatchError; use std::{ borrow::Cow, collections::HashMap, @@ -502,22 +503,17 @@ where match self.node.at(&path) { Ok(match_) => self.call_route(match_, req), - Err(err) => { - if err.tsr() { - let redirect_to = if let Some(without_tsr) = path.strip_suffix('/') { - with_path(req.uri(), without_tsr) - } else { - with_path(req.uri(), &format!("{}/", path)) - }; - let res = Redirect::permanent(redirect_to); - RouteFuture::from_response(res.into_response()) - } else { - match &self.fallback { - Fallback::Default(inner) => inner.clone().call(req), - Fallback::Custom(inner) => inner.clone().call(req), - } - } - } + Err(MatchError::MissingTrailingSlash) => RouteFuture::from_response( + Redirect::permanent(with_path(req.uri(), &format!("{}/", path))).into_response(), + ), + Err(MatchError::ExtraTrailingSlash) => RouteFuture::from_response( + Redirect::permanent(with_path(req.uri(), path.strip_suffix('/').unwrap())) + .into_response(), + ), + Err(MatchError::NotFound) => match &self.fallback { + Fallback::Default(inner) => inner.clone().call(req), + Fallback::Custom(inner) => inner.clone().call(req), + }, } } } @@ -551,10 +547,10 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri { Uri::from_parts(parts).unwrap() } -/// Wrapper around `matchit::Node` that supports merging two `Node`s. +/// Wrapper around `matchit::Router` that supports merging two `Router`s. #[derive(Clone, Default)] struct Node { - inner: matchit::Node, + inner: matchit::Router, route_id_to_path: HashMap>, path_to_route_id: HashMap, RouteId>, } @@ -579,7 +575,7 @@ impl Node { fn at<'n, 'p>( &'n self, path: &'p str, - ) -> Result, matchit::MatchError> { + ) -> Result, MatchError> { self.inner.at(path) } }