1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

axum/routing: Merge fallbacks with the rest of the router

This commit is contained in:
David Mládek 2025-01-08 00:02:49 +01:00
parent b5236eaff4
commit 8dbe9114b5
5 changed files with 252 additions and 104 deletions
axum/src

View file

@ -36,8 +36,7 @@ documentation for more details.
It is not possible to create segments that only match some types like numbers or
regular expression. You must handle that manually in your handlers.
[`MatchedPath`](crate::extract::MatchedPath) can be used to extract the matched
path rather than the actual path.
[`MatchedPath`] can be used to extract the matched path rather than the actual path.
# Wildcards

View file

@ -3,6 +3,8 @@
use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
#[cfg(feature = "matched-path")]
use crate::extract::MatchedPath;
use crate::{
body::{Body, HttpBody},
boxed::BoxedIntoRoute,
@ -20,7 +22,8 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower::service_fn;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;
pub mod future;
@ -72,8 +75,7 @@ impl<S> Clone for Router<S> {
}
struct RouterInner<S> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
path_router: PathRouter<S>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
}
@ -91,7 +93,6 @@ impl<S> fmt::Debug for Router<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.inner.path_router)
.field("fallback_router", &self.inner.fallback_router)
.field("default_fallback", &self.inner.default_fallback)
.field("catch_all_fallback", &self.inner.catch_all_fallback)
.finish()
@ -141,7 +142,6 @@ where
Self {
inner: Arc::new(RouterInner {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}),
@ -153,7 +153,6 @@ where
Ok(inner) => inner,
Err(arc) => RouterInner {
path_router: arc.path_router.clone(),
fallback_router: arc.fallback_router.clone(),
default_fallback: arc.default_fallback,
catch_all_fallback: arc.catch_all_fallback.clone(),
},
@ -207,8 +206,7 @@ where
let RouterInner {
path_router,
fallback_router,
default_fallback,
default_fallback: _,
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
// requests with an empty path. If we were to inherit the catch-all fallback
// it would end up matching `/{path}/*` which doesn't match empty paths.
@ -217,10 +215,6 @@ where
tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest(path, path_router));
if !default_fallback {
panic_on_err!(this.fallback_router.nest(path, fallback_router));
}
})
}
@ -247,36 +241,24 @@ where
where
R: Into<Router<S>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
let other: Router<S> = other.into();
let RouterInner {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into_inner();
map_inner!(self, mut this => {
panic_on_err!(this.path_router.merge(path_router));
match (this.default_fallback, default_fallback) {
// both have the default fallback
// use the one from other
(true, true) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
(true, true) => {}
// this has default fallback, other has a custom fallback
(true, false) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.default_fallback = false;
}
// this has a custom fallback, other has a default
(false, true) => {
let fallback_router = std::mem::take(&mut this.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
this.fallback_router = other_fallback;
}
// both have a custom fallback, not allowed
(false, false) => {
@ -284,6 +266,8 @@ where
}
};
panic_on_err!(this.path_router.merge(path_router));
this.catch_all_fallback = this
.catch_all_fallback
.merge(catch_all_fallback)
@ -304,7 +288,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
})
@ -322,7 +305,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router,
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback,
})
@ -376,8 +358,51 @@ where
}
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
// TODO make this better, get rid of the `unwrap`s.
// We need the returned `Service` to be `Clone` and the function inside `service_fn` to be
// `FnMut` so instead of just using the owned service, we do this trick with `Option`. We
// know this will be called just once so it's fine. We're doing that so that we avoid one
// clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not
// cloned too much.
tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint);
_ = this.path_router.route_endpoint(
"/",
endpoint.clone().layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);
_ = this.path_router.route_endpoint(
FALLBACK_PARAM_PATH,
endpoint.layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);
this.default_fallback = false;
})
}
@ -386,7 +411,6 @@ where
pub fn with_state<S2>(self, state: S) -> Router<S2> {
map_inner!(self, this => RouterInner {
path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.with_state(state),
})
@ -398,11 +422,6 @@ where
Err((req, state)) => (req, state),
};
let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};
self.inner
.catch_all_fallback
.clone()

View file

@ -9,33 +9,17 @@ use tower_layer::Layer;
use tower_service::Service;
use super::{
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
RouteId, NEST_TAIL_PARAM,
};
pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
pub(super) struct PathRouter<S> {
routes: HashMap<RouteId, Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}
impl<S> PathRouter<S, true>
where
S: Clone + Send + Sync + 'static,
{
pub(super) fn new_fallback() -> Self {
let mut this = Self::default();
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
this
}
pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
self.replace_endpoint("/", endpoint.clone());
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
}
}
fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
.unwrap_or(Ok(()))
}
impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
impl<S> PathRouter<S>
where
S: Clone + Send + Sync + 'static,
{
@ -159,10 +143,7 @@ where
.map_err(|err| format!("Invalid route {path:?}: {err}"))
}
pub(super) fn merge(
&mut self,
other: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
pub(super) fn merge(&mut self, other: PathRouter<S>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
@ -179,24 +160,9 @@ where
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");
if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
// when merging two routers it doesn't matter if you do `a.merge(b)` or
// `b.merge(a)`. This must also be true for fallbacks.
//
// However all fallback routers will have routes for `/` and `/*` so when merging
// we have to ignore the top level fallbacks on one side otherwise we get
// conflicts.
//
// `Router::merge` makes sure that when merging fallbacks `other` always has the
// fallback we want to keep. It panics if both routers have a custom fallback. Thus
// it is always okay to ignore one fallback and `Router::merge` also makes sure the
// one we can ignore is that of `self`.
self.replace_endpoint(path, route);
} else {
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
}
@ -206,7 +172,7 @@ where
pub(super) fn nest(
&mut self,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
router: PathRouter<S>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);
@ -282,7 +248,7 @@ where
Ok(())
}
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S>
where
L: Layer<Route> + Clone + Send + Sync + 'static,
L::Service: Service<Request> + Clone + Send + Sync + 'static,
@ -344,7 +310,7 @@ where
!self.routes.is_empty()
}
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2> {
let routes = self
.routes
.into_iter()
@ -388,14 +354,12 @@ where
Ok(match_) => {
let id = *match_.value;
if !IS_FALLBACK {
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);
}
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);
url_params::insert_url_params(&mut parts.extensions, match_.params);
@ -418,18 +382,6 @@ where
}
}
pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
match self.node.at(path) {
Ok(match_) => {
let id = *match_.value;
self.routes.insert(id, endpoint);
}
Err(_) => self
.route_endpoint(path, endpoint)
.expect("path wasn't matched so endpoint shouldn't exist"),
}
}
fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
@ -441,7 +393,7 @@ where
}
}
impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
impl<S> Default for PathRouter<S> {
fn default() -> Self {
Self {
routes: Default::default(),
@ -452,7 +404,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
}
}
impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
impl<S> fmt::Debug for PathRouter<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
@ -461,7 +413,7 @@ impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
}
}
impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
impl<S> Clone for PathRouter<S> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),

View file

@ -407,6 +407,87 @@ async fn what_matches_wildcard() {
assert_eq!(get("/x/a/b/").await, "x");
}
#[should_panic(
expected = "Invalid route \"/{*wild}\": Insertion failed due to conflict with previously registered route: /{*__private__axum_fallback}"
)]
#[test]
fn colliding_fallback_with_wildcard() {
_ = Router::<()>::new()
.fallback(|| async { "fallback" })
.route("/{*wild}", get(|| async { "wildcard" }));
}
// We might want to reject this too
#[crate::test]
async fn colliding_wildcard_with_fallback() {
let router = Router::new()
.route("/{*wild}", get(|| async { "wildcard" }))
.fallback(|| async { "fallback" });
let client = TestClient::new(router);
let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "fallback");
let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "wildcard");
}
// We might want to reject this too
#[crate::test]
async fn colliding_fallback_with_fallback() {
let router = Router::new()
.fallback(|| async { "fallback1" })
.fallback(|| async { "fallback2" });
let client = TestClient::new(router);
let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "fallback1");
let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback1");
}
#[crate::test]
async fn colliding_root_with_fallback() {
let router = Router::new()
.route("/", get(|| async { "root" }))
.fallback(|| async { "fallback" });
let client = TestClient::new(router);
let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "root");
let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback");
}
#[crate::test]
async fn colliding_fallback_with_root() {
let router = Router::new()
.fallback(|| async { "fallback" })
.route("/", get(|| async { "root" }));
let client = TestClient::new(router);
// This works because fallback registers `any` so the `get` gets merged into it.
let res = client.get("/").await;
let body = res.text().await;
assert_eq!(body, "root");
let res = client.get("/x").await;
let body = res.text().await;
assert_eq!(body, "fallback");
}
#[crate::test]
async fn static_and_dynamic_paths() {
let app = Router::new()

View file

@ -387,3 +387,100 @@ async fn colon_in_route() {
async fn asterisk_in_route() {
_ = Router::<()>::new().nest("/*foo", Router::new());
}
#[crate::test]
async fn nesting_router_with_fallback() {
let nested = Router::new().fallback(|| async { "nested" });
let router = Router::new().route("/{x}/{y}", get(|| async { "two segments" }));
let client = TestClient::new(router.nest("/nest", nested));
let res = client.get("/a/b").await;
let body = res.text().await;
assert_eq!(body, "two segments");
let res = client.get("/nest/b").await;
let body = res.text().await;
assert_eq!(body, "nested");
}
#[crate::test]
async fn defining_missing_routes_in_nested_router() {
let router = Router::new()
.route("/nest/before", get(|| async { "before" }))
.nest(
"/nest",
Router::new()
.route("/mid", get(|| async { "nested mid" }))
.fallback(|| async { "nested fallback" }),
)
.route("/nest/after", get(|| async { "after" }));
let client = TestClient::new(router);
let res = client.get("/nest/before").await;
let body = res.text().await;
assert_eq!(body, "before");
let res = client.get("/nest/after").await;
let body = res.text().await;
assert_eq!(body, "after");
let res = client.get("/nest/mid").await;
let body = res.text().await;
assert_eq!(body, "nested mid");
let res = client.get("/nest/fallback").await;
let body = res.text().await;
assert_eq!(body, "nested fallback");
}
#[test]
#[should_panic(
expected = "Overlapping method route. Handler for `GET /nest/override` already exists"
)]
fn overriding_by_nested_router() {
_ = Router::<()>::new()
.route("/nest/override", get(|| async { "outer" }))
.nest(
"/nest",
Router::new().route("/override", get(|| async { "inner" })),
);
}
#[test]
#[should_panic(
expected = "Overlapping method route. Handler for `GET /nest/override` already exists"
)]
fn overriding_nested_router_() {
_ = Router::<()>::new()
.nest(
"/nest",
Router::new().route("/override", get(|| async { "inner" })),
)
.route("/nest/override", get(|| async { "outer" }));
}
// This is just documenting current state, not intended behavior.
#[crate::test]
async fn overriding_nested_service_router() {
let router = Router::new()
.route("/nest/before", get(|| async { "outer" }))
.nest_service(
"/nest",
Router::new()
.route("/before", get(|| async { "inner" }))
.route("/after", get(|| async { "inner" })),
)
.route("/nest/after", get(|| async { "outer" }));
let client = TestClient::new(router);
let res = client.get("/nest/before").await;
let body = res.text().await;
assert_eq!(body, "outer");
let res = client.get("/nest/after").await;
let body = res.text().await;
assert_eq!(body, "outer");
}