diff --git a/axum/src/docs/routing/method_not_allowed_fallback.md b/axum/src/docs/routing/method_not_allowed_fallback.md new file mode 100644 index 00000000..22905cd9 --- /dev/null +++ b/axum/src/docs/routing/method_not_allowed_fallback.md @@ -0,0 +1,38 @@ +Add a fallback [`Handler`] for the case where a route exists, but the method of the request is not supported. + +Sets a fallback on all previously registered [`MethodRouter`]s, +to be called when no matching method handler is set. + +```rust,no_run +use axum::{response::IntoResponse, routing::get, Router}; + +async fn hello_world() -> impl IntoResponse { + "Hello, world!\n" +} + +async fn default_fallback() -> impl IntoResponse { + "Default fallback\n" +} + +async fn handle_405() -> impl IntoResponse { + "Method not allowed fallback" +} + +#[tokio::main] +async fn main() { + let router = Router::new() + .route("/", get(hello_world)) + .fallback(default_fallback) + .method_not_allowed_fallback(handle_405); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + + axum::serve(listener, router).await.unwrap(); +} +``` + +The fallback only applies if there is a `MethodRouter` registered for a given path, +but the method used in the request is not specified. In the example, a `GET` on +`http://localhost:3000` causes the `hello_world` handler to react, while issuing a +`POST` triggers `handle_405`. Calling an entirely different route, like `http://localhost:3000/hello` +causes `default_fallback` to run. diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 28fa314f..ecc1dae0 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -658,6 +658,19 @@ where self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } + + /// Add a fallback [`Handler`] if no custom one has been provided. + pub(crate) fn default_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Send + Sync + 'static, + { + match self.fallback { + Fallback::Default(_) => self.fallback(handler), + _ => self, + } + } } impl MethodRouter<(), Infallible> { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 7499e1fc..57a3d430 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -347,6 +347,18 @@ where .fallback_endpoint(Endpoint::Route(route)) } + #[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")] + pub fn method_not_allowed_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + tap_inner!(self, mut this => { + this.path_router + .method_not_allowed_fallback(handler.clone()) + }) + } + fn fallback_endpoint(self, endpoint: Endpoint) -> Self { tap_inner!(self, mut this => { this.fallback_router.set_fallback(endpoint); diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 5e317b90..400ce32d 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -1,4 +1,7 @@ -use crate::extract::{nested_path::SetNestedPath, Request}; +use crate::{ + extract::{nested_path::SetNestedPath, Request}, + handler::Handler, +}; use axum_core::response::IntoResponse; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; @@ -79,6 +82,18 @@ where Ok(()) } + pub(super) fn method_not_allowed_fallback(&mut self, handler: H) + where + H: Handler, + T: 'static, + { + for (_, endpoint) in self.routes.iter_mut() { + if let Endpoint::MethodRouter(rt) = endpoint { + *rt = rt.clone().default_fallback(handler.clone()); + } + } + } + pub(super) fn route_service( &mut self, path: &str, diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 02850b19..9dd1c6c2 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -329,6 +329,51 @@ async fn merge_router_with_fallback_into_empty() { assert_eq!(res.text().await, "outer"); } +#[crate::test] +async fn mna_fallback_with_existing_fallback() { + let app = Router::new() + .route( + "/", + get(|| async { "test" }).fallback(|| async { "index fallback" }), + ) + .route("/path", get(|| async { "path" })) + .method_not_allowed_fallback(|| async { "method not allowed fallback" }); + + let client = TestClient::new(app); + let index_fallback = client.post("/").await; + let method_not_allowed_fallback = client.post("/path").await; + + assert_eq!(index_fallback.text().await, "index fallback"); + assert_eq!( + method_not_allowed_fallback.text().await, + "method not allowed fallback" + ); +} + +#[crate::test] +async fn mna_fallback_with_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .method_not_allowed_fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "state"); +} + +#[crate::test] +async fn mna_fallback_with_unused_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .with_state(()) + .method_not_allowed_fallback(|| async move { "bla" }); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "bla"); +} + #[crate::test] async fn state_isnt_cloned_too_much_with_fallback() { let state = CountingCloneableState::new();