mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-16 22:43:03 +01:00
Add fallback inheritance for nested routers (#1521)
* fallback inheritance * cleanup * changelog
This commit is contained in:
parent
2e8a7e51a1
commit
7090649377
7 changed files with 218 additions and 60 deletions
|
@ -7,8 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
# Unreleased
|
||||
|
||||
- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])
|
||||
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])
|
||||
|
||||
[#1521]: https://github.com/tokio-rs/axum/pull/1521
|
||||
|
||||
# 0.6.0-rc.4 (9. November, 2022)
|
||||
|
||||
- **changed**: The inner error of a `JsonRejection` is now
|
||||
|
|
|
@ -90,8 +90,8 @@ let app = Router::new()
|
|||
|
||||
# Fallbacks
|
||||
|
||||
When nesting a router, if a request matches the prefix but the nested router doesn't have a matching
|
||||
route, the outer fallback will _not_ be called:
|
||||
If a nested router doesn't have its own fallback then it will inherit the
|
||||
fallback from the outer router:
|
||||
|
||||
```rust
|
||||
use axum::{routing::get, http::StatusCode, handler::Handler, Router};
|
||||
|
@ -100,7 +100,7 @@ async fn fallback() -> (StatusCode, &'static str) {
|
|||
(StatusCode::NOT_FOUND, "Not Found")
|
||||
}
|
||||
|
||||
let api_routes = Router::new().nest_service("/users", get(|| async {}));
|
||||
let api_routes = Router::new().route("/users", get(|| async {}));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", api_routes)
|
||||
|
@ -108,30 +108,35 @@ let app = Router::new()
|
|||
# let _: Router = app;
|
||||
```
|
||||
|
||||
Here requests like `GET /api/not-found` will go into `api_routes` and then to
|
||||
the fallback of `api_routes` which will return an empty `404 Not Found`
|
||||
response. The outer fallback declared on `app` will _not_ be called.
|
||||
Here requests like `GET /api/not-found` will go into `api_routes` but because
|
||||
it doesn't have a matching route and doesn't have its own fallback it will call
|
||||
the fallback from the outer router, i.e. the `fallback` function.
|
||||
|
||||
Think of nested services as swallowing requests that matches the prefix and
|
||||
not falling back to outer router even if they don't have a matching route.
|
||||
|
||||
You can still add separate fallbacks to nested routers:
|
||||
If the nested router has its own fallback then the outer fallback will not be
|
||||
inherited:
|
||||
|
||||
```rust
|
||||
use axum::{routing::get, http::StatusCode, handler::Handler, Json, Router};
|
||||
use serde_json::{json, Value};
|
||||
use axum::{
|
||||
routing::get,
|
||||
http::StatusCode,
|
||||
handler::Handler,
|
||||
Json,
|
||||
Router,
|
||||
};
|
||||
|
||||
async fn fallback() -> (StatusCode, &'static str) {
|
||||
(StatusCode::NOT_FOUND, "Not Found")
|
||||
}
|
||||
|
||||
async fn api_fallback() -> (StatusCode, Json<Value>) {
|
||||
(StatusCode::NOT_FOUND, Json(json!({ "error": "Not Found" })))
|
||||
async fn api_fallback() -> (StatusCode, Json<serde_json::Value>) {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({ "status": "Not Found" })),
|
||||
)
|
||||
}
|
||||
|
||||
let api_routes = Router::new()
|
||||
.nest_service("/users", get(|| async {}))
|
||||
// add dedicated fallback for requests starting with `/api`
|
||||
.route("/users", get(|| async {}))
|
||||
.fallback(api_fallback);
|
||||
|
||||
let app = Router::new()
|
||||
|
@ -140,6 +145,8 @@ let app = Router::new()
|
|||
# let _: Router = app;
|
||||
```
|
||||
|
||||
Here requests like `GET /api/not-found` will go to `api_fallback`.
|
||||
|
||||
# Panics
|
||||
|
||||
- If the route overlaps with another route. See [`Router::route`]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! Route to services and handlers based on HTTP methods.
|
||||
|
||||
use super::IntoMakeService;
|
||||
use super::{FallbackRoute, IntoMakeService};
|
||||
#[cfg(feature = "tokio")]
|
||||
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
|
||||
use crate::{
|
||||
|
@ -744,7 +744,7 @@ where
|
|||
post: self.post.into_route(&state),
|
||||
put: self.put.into_route(&state),
|
||||
trace: self.trace.into_route(&state),
|
||||
fallback: self.fallback.into_route(&state),
|
||||
fallback: self.fallback.into_fallback_route(&state),
|
||||
allow_header: self.allow_header,
|
||||
}
|
||||
}
|
||||
|
@ -1284,7 +1284,7 @@ pub struct WithState<B, E> {
|
|||
post: Option<Route<B, E>>,
|
||||
put: Option<Route<B, E>>,
|
||||
trace: Option<Route<B, E>>,
|
||||
fallback: Route<B, E>,
|
||||
fallback: FallbackRoute<B, E>,
|
||||
allow_header: AllowHeader,
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::{
|
|||
handler::{BoxedHandler, Handler},
|
||||
util::try_downcast,
|
||||
};
|
||||
use axum_core::response::IntoResponse;
|
||||
use axum_core::response::{IntoResponse, Response};
|
||||
use http::Request;
|
||||
use matchit::MatchError;
|
||||
use std::{
|
||||
|
@ -18,7 +18,10 @@ use std::{
|
|||
fmt,
|
||||
sync::Arc,
|
||||
};
|
||||
use tower::{util::MapResponseLayer, ServiceBuilder};
|
||||
use tower::{
|
||||
util::{BoxCloneService, MapResponseLayer, Oneshot},
|
||||
ServiceBuilder,
|
||||
};
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
|
@ -639,11 +642,29 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn into_route(self, state: &S) -> Route<B, E> {
|
||||
fn into_fallback_route(self, state: &S) -> FallbackRoute<B, E> {
|
||||
match self {
|
||||
Self::Default(route) => route,
|
||||
Self::Service(route) => route,
|
||||
Self::BoxedHandler(handler) => handler.into_route(state.clone()),
|
||||
Self::Default(route) => FallbackRoute::Default(route),
|
||||
Self::Service(route) => FallbackRoute::Service(route),
|
||||
Self::BoxedHandler(handler) => {
|
||||
FallbackRoute::Service(handler.into_route(state.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
|
||||
where
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
|
||||
B2: 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
match self {
|
||||
Self::Default(route) => Fallback::Default(f(route)),
|
||||
Self::Service(route) => Fallback::Service(f(route)),
|
||||
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -668,20 +689,38 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<S, B, E> Fallback<S, B, E> {
|
||||
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
|
||||
where
|
||||
S: 'static,
|
||||
B: 'static,
|
||||
E: 'static,
|
||||
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
|
||||
B2: 'static,
|
||||
E2: 'static,
|
||||
{
|
||||
/// Like `Fallback` but without the `S` param so it can be stored in `RouterService`
|
||||
pub(crate) enum FallbackRoute<B, E = Infallible> {
|
||||
Default(Route<B, E>),
|
||||
Service(Route<B, E>),
|
||||
}
|
||||
|
||||
impl<B, E> fmt::Debug for FallbackRoute<B, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Default(inner) => Fallback::Default(f(inner)),
|
||||
Self::Service(inner) => Fallback::Service(f(inner)),
|
||||
Self::BoxedHandler(inner) => Fallback::BoxedHandler(inner.map(f)),
|
||||
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
|
||||
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Clone for FallbackRoute<B, E> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Default(inner) => Self::Default(inner.clone()),
|
||||
Self::Service(inner) => Self::Service(inner.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> FallbackRoute<B, E> {
|
||||
pub(crate) fn oneshot_inner(
|
||||
&mut self,
|
||||
req: Request<B>,
|
||||
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
|
||||
match self {
|
||||
FallbackRoute::Default(inner) => inner.oneshot_inner(req),
|
||||
FallbackRoute::Service(inner) => inner.oneshot_inner(req),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use super::{future::RouteFuture, url_params, Endpoint, Node, Route, RouteId, Router};
|
||||
use super::{
|
||||
future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router,
|
||||
};
|
||||
use crate::{
|
||||
body::{Body, HttpBody},
|
||||
response::Response,
|
||||
|
@ -11,6 +13,7 @@ use std::{
|
|||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use sync_wrapper::SyncWrapper;
|
||||
use tower::Service;
|
||||
|
||||
/// A [`Router`] converted into a [`Service`].
|
||||
|
@ -18,7 +21,7 @@ use tower::Service;
|
|||
pub struct RouterService<B = Body> {
|
||||
routes: HashMap<RouteId, Route<B>>,
|
||||
node: Arc<Node>,
|
||||
fallback: Route<B>,
|
||||
fallback: FallbackRoute<B>,
|
||||
}
|
||||
|
||||
impl<B> RouterService<B>
|
||||
|
@ -52,7 +55,7 @@ where
|
|||
Self {
|
||||
routes,
|
||||
node: router.node,
|
||||
fallback: router.fallback.into_route(&state),
|
||||
fallback: router.fallback.into_fallback_route(&state),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -121,12 +124,35 @@ where
|
|||
let path = req.uri().path().to_owned();
|
||||
|
||||
match self.node.at(&path) {
|
||||
Ok(match_) => self.call_route(match_, req),
|
||||
Ok(match_) => {
|
||||
match &self.fallback {
|
||||
FallbackRoute::Default(_) => {}
|
||||
FallbackRoute::Service(fallback) => {
|
||||
req.extensions_mut()
|
||||
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
|
||||
}
|
||||
}
|
||||
|
||||
self.call_route(match_, req)
|
||||
}
|
||||
Err(
|
||||
MatchError::NotFound
|
||||
| MatchError::ExtraTrailingSlash
|
||||
| MatchError::MissingTrailingSlash,
|
||||
) => self.fallback.clone().call(req),
|
||||
) => match &mut self.fallback {
|
||||
FallbackRoute::Default(fallback) => {
|
||||
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
|
||||
{
|
||||
let mut super_fallback = super_fallback.0.into_inner();
|
||||
super_fallback.call(req)
|
||||
} else {
|
||||
fallback.call(req)
|
||||
}
|
||||
}
|
||||
FallbackRoute::Service(fallback) => fallback.call(req),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SuperFallback<B>(SyncWrapper<Route<B>>);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::*;
|
||||
use crate::middleware::{map_request, map_response};
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic() {
|
||||
|
@ -58,3 +59,102 @@ async fn fallback_accessing_state() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "state");
|
||||
}
|
||||
|
||||
async fn inner_fallback() -> impl IntoResponse {
|
||||
(StatusCode::NOT_FOUND, "inner")
|
||||
}
|
||||
|
||||
async fn outer_fallback() -> impl IntoResponse {
|
||||
(StatusCode::NOT_FOUND, "outer")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nested_router_inherits_fallback() {
|
||||
let inner = Router::new();
|
||||
let app = Router::new().nest("/foo", inner).fallback(outer_fallback);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn doesnt_inherit_fallback_if_overriden() {
|
||||
let inner = Router::new().fallback(inner_fallback);
|
||||
let app = Router::new().nest("/foo", inner).fallback(outer_fallback);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "inner");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deeply_nested_inherit_from_top() {
|
||||
let app = Router::new()
|
||||
.nest("/foo", Router::new().nest("/bar", Router::new()))
|
||||
.fallback(outer_fallback);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar/baz").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deeply_nested_inherit_from_middle() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.nest("/bar", Router::new())
|
||||
.fallback(outer_fallback),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar/baz").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn with_middleware_on_inner_fallback() {
|
||||
async fn never_called<B>(_: Request<B>) -> Request<B> {
|
||||
panic!("should never be called")
|
||||
}
|
||||
|
||||
let inner = Router::new().layer(map_request(never_called));
|
||||
let app = Router::new().nest("/foo", inner).fallback(outer_fallback);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn also_inherits_default_layered_fallback() {
|
||||
async fn set_header<B>(mut res: Response<B>) -> Response<B> {
|
||||
res.headers_mut()
|
||||
.insert("x-from-fallback", "1".parse().unwrap());
|
||||
res
|
||||
}
|
||||
|
||||
let inner = Router::new();
|
||||
let app = Router::new()
|
||||
.nest("/foo", inner)
|
||||
.fallback(outer_fallback)
|
||||
.layer(map_response(set_header));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo/bar").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
assert_eq!(res.headers()["x-from-fallback"], "1");
|
||||
assert_eq!(res.text().await, "outer");
|
||||
}
|
||||
|
|
|
@ -351,23 +351,6 @@ async fn nest_with_and_without_trailing() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn doesnt_call_outer_fallback() {
|
||||
let app = Router::new()
|
||||
.nest("/foo", Router::new().route("/", get(|| async {})))
|
||||
.fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") });
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.get("/foo/not-found").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
// the default fallback returns an empty body
|
||||
assert_eq!(res.text().await, "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_with_root_inner_router() {
|
||||
let app = Router::new().nest(
|
||||
|
|
Loading…
Reference in a new issue