Add method_not_allowed_fallback to router (#2903)

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
Leon Lux 2024-10-12 11:36:56 +02:00 committed by GitHub
parent 0712a46cd9
commit 73db1631c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 124 additions and 1 deletions

View file

@ -0,0 +1,38 @@
Add a fallback [`Handler`] for the case where a route exists, but the method of the request is not supported.
Sets a fallback on all previously registered [`MethodRouter`]s,
to be called when no matching method handler is set.
```rust,no_run
use axum::{response::IntoResponse, routing::get, Router};
async fn hello_world() -> impl IntoResponse {
"Hello, world!\n"
}
async fn default_fallback() -> impl IntoResponse {
"Default fallback\n"
}
async fn handle_405() -> impl IntoResponse {
"Method not allowed fallback"
}
#[tokio::main]
async fn main() {
let router = Router::new()
.route("/", get(hello_world))
.fallback(default_fallback)
.method_not_allowed_fallback(handle_405);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();
}
```
The fallback only applies if there is a `MethodRouter` registered for a given path,
but the method used in the request is not specified. In the example, a `GET` on
`http://localhost:3000` causes the `hello_world` handler to react, while issuing a
`POST` triggers `handle_405`. Calling an entirely different route, like `http://localhost:3000/hello`
causes `default_fallback` to run.

View file

@ -659,6 +659,19 @@ where
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
self self
} }
/// Add a fallback [`Handler`] if no custom one has been provided.
pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
S: Send + Sync + 'static,
{
match self.fallback {
Fallback::Default(_) => self.fallback(handler),
_ => self,
}
}
} }
impl MethodRouter<(), Infallible> { impl MethodRouter<(), Infallible> {

View file

@ -354,6 +354,18 @@ where
.fallback_endpoint(Endpoint::Route(route)) .fallback_endpoint(Endpoint::Route(route))
} }
#[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")]
pub fn method_not_allowed_fallback<H, T>(self, handler: H) -> Self
where
H: Handler<T, S>,
T: 'static,
{
tap_inner!(self, mut this => {
this.path_router
.method_not_allowed_fallback(handler.clone())
})
}
fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self { fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
tap_inner!(self, mut this => { tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint); this.fallback_router.set_fallback(endpoint);

View file

@ -1,4 +1,7 @@
use crate::extract::{nested_path::SetNestedPath, Request}; use crate::{
extract::{nested_path::SetNestedPath, Request},
handler::Handler,
};
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use matchit::MatchError; use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
@ -110,6 +113,18 @@ where
Ok(()) Ok(())
} }
pub(super) fn method_not_allowed_fallback<H, T>(&mut self, handler: H)
where
H: Handler<T, S>,
T: 'static,
{
for (_, endpoint) in self.routes.iter_mut() {
if let Endpoint::MethodRouter(rt) = endpoint {
*rt = rt.clone().default_fallback(handler.clone());
}
}
}
pub(super) fn route_service<T>( pub(super) fn route_service<T>(
&mut self, &mut self,
path: &str, path: &str,

View file

@ -329,6 +329,51 @@ async fn merge_router_with_fallback_into_empty() {
assert_eq!(res.text().await, "outer"); assert_eq!(res.text().await, "outer");
} }
#[crate::test]
async fn mna_fallback_with_existing_fallback() {
let app = Router::new()
.route(
"/",
get(|| async { "test" }).fallback(|| async { "index fallback" }),
)
.route("/path", get(|| async { "path" }))
.method_not_allowed_fallback(|| async { "method not allowed fallback" });
let client = TestClient::new(app);
let index_fallback = client.post("/").await;
let method_not_allowed_fallback = client.post("/path").await;
assert_eq!(index_fallback.text().await, "index fallback");
assert_eq!(
method_not_allowed_fallback.text().await,
"method not allowed fallback"
);
}
#[crate::test]
async fn mna_fallback_with_state() {
let app = Router::new()
.route("/", get(|| async { "index" }))
.method_not_allowed_fallback(|State(state): State<&'static str>| async move { state })
.with_state("state");
let client = TestClient::new(app);
let res = client.post("/").await;
assert_eq!(res.text().await, "state");
}
#[crate::test]
async fn mna_fallback_with_unused_state() {
let app = Router::new()
.route("/", get(|| async { "index" }))
.with_state(())
.method_not_allowed_fallback(|| async move { "bla" });
let client = TestClient::new(app);
let res = client.post("/").await;
assert_eq!(res.text().await, "bla");
}
#[crate::test] #[crate::test]
async fn state_isnt_cloned_too_much_with_fallback() { async fn state_isnt_cloned_too_much_with_fallback() {
let state = CountingCloneableState::new(); let state = CountingCloneableState::new();