diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md
index 546553a9..0feea34c 100644
--- a/axum/CHANGELOG.md
+++ b/axum/CHANGELOG.md
@@ -7,7 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
-- None.
+- **breaking:** New `MethodRouter` that works similarly to `Router`:
+ - Route to handlers and services with the same type
+ - Add middleware to some routes more easily with `MethodRouter::layer` and
+ `MethodRouter::route_layer`.
+ - Merge method routers with `MethodRouter::merge`
+ - Customize response for unsupported methods with `MethodRouter::fallback`
+- **fixed:** Adding the same route with different methods now works ie
+ `.route("/", get(_)).route("/", post(_))`.
+- **breaking:** `routing::handler_method_router` and
+ `routing::service_method_router` has been removed in favor of
+ `routing::{get, get_service, ..., MethodRouter}`.
+- **breaking:** `HandleErrorExt` has been removed in favor of
+ `MethodRouter::handle_error`.
# 0.3.3 (13. November, 2021)
diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md
index 7c9c3572..508302ab 100644
--- a/axum/src/docs/error_handling.md
+++ b/axum/src/docs/error_handling.md
@@ -43,14 +43,12 @@ functions as handlers. However if you're embedding general `Service`s or
applying middleware, which might produce errors you have to tell axum how to
convert those errors into responses.
-You can handle errors from services using [`HandleErrorExt::handle_error`]:
-
```rust
use axum::{
Router,
body::Body,
http::{Request, Response, StatusCode},
- error_handling::HandleErrorExt, // for `.handle_error()`
+ error_handling::HandleError,
};
async fn thing_that_might_fail() -> Result<(), anyhow::Error> {
@@ -69,7 +67,7 @@ let app = Router::new().route(
// we cannot route to `some_fallible_service` directly since it might fail.
// we have to use `handle_error` which converts its errors into responses
// and changes its error type from `anyhow::Error` to `Infallible`.
- some_fallible_service.handle_error(handle_anyhow_error),
+ HandleError::new(some_fallible_service, handle_anyhow_error),
);
// handle errors by converting them into something that implements
diff --git a/axum/src/docs/method_routing/fallback.md b/axum/src/docs/method_routing/fallback.md
new file mode 100644
index 00000000..c027578c
--- /dev/null
+++ b/axum/src/docs/method_routing/fallback.md
@@ -0,0 +1,53 @@
+Add a fallback service to the router.
+
+This service will be called if no routes matches the incoming request.
+
+```rust
+use axum::{
+ Router,
+ routing::get,
+ handler::Handler,
+ response::IntoResponse,
+ http::{StatusCode, Method, Uri},
+};
+
+let handler = get(|| async {}).fallback(fallback.into_service());
+
+let app = Router::new().route("/", handler);
+
+async fn fallback(method: Method, uri: Uri) -> impl IntoResponse {
+ (StatusCode::NOT_FOUND, format!("`{}` not allowed for {}", method, uri))
+}
+# async {
+# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+# };
+```
+
+## When used with `MethodRouter::merge`
+
+Two routers that both have a fallback cannot be merged. Doing so results in a
+panic:
+
+```rust,should_panic
+use axum::{
+ routing::{get, post},
+ handler::Handler,
+ response::IntoResponse,
+ http::{StatusCode, Uri},
+};
+
+let one = get(|| async {})
+ .fallback(fallback_one.into_service());
+
+let two = post(|| async {})
+ .fallback(fallback_two.into_service());
+
+let method_route = one.merge(two);
+
+async fn fallback_one() -> impl IntoResponse {}
+async fn fallback_two() -> impl IntoResponse {}
+# let app = axum::Router::new().route("/", method_route);
+# async {
+# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+# };
+```
diff --git a/axum/src/docs/method_routing/layer.md b/axum/src/docs/method_routing/layer.md
new file mode 100644
index 00000000..10bde2ab
--- /dev/null
+++ b/axum/src/docs/method_routing/layer.md
@@ -0,0 +1,28 @@
+Apply a [`tower::Layer`] to the router.
+
+All requests to the router will be processed by the layer's
+corresponding middleware.
+
+This can be used to add additional processing to a request for a group
+of routes.
+
+Works similarly to [`Router::layer`](super::Router::layer). See that method for
+more details.
+
+# Example
+
+```rust
+use axum::{routing::get, Router};
+use tower::limit::ConcurrencyLimitLayer;
+
+async fn hander() {}
+
+let app = Router::new().route(
+ "/",
+ // All requests to `GET /` will be sent through `ConcurrencyLimitLayer`
+ get(hander).layer(ConcurrencyLimitLayer::new(64)),
+);
+# async {
+# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+# };
+```
diff --git a/axum/src/docs/method_routing/merge.md b/axum/src/docs/method_routing/merge.md
new file mode 100644
index 00000000..39d74d04
--- /dev/null
+++ b/axum/src/docs/method_routing/merge.md
@@ -0,0 +1,25 @@
+Merge two routers into one.
+
+This is useful for breaking routers into smaller pieces and combining them
+into one.
+
+```rust
+use axum::{
+ routing::{get, post},
+ Router,
+};
+
+let get = get(|| async {});
+let post = post(|| async {});
+
+let merged = get.merge(post);
+
+let app = Router::new().route("/", merged);
+
+// Our app now accepts
+// - GET /
+// - POST /
+# async {
+# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+# };
+```
diff --git a/axum/src/docs/method_routing/route_layer.md b/axum/src/docs/method_routing/route_layer.md
new file mode 100644
index 00000000..c2e2c061
--- /dev/null
+++ b/axum/src/docs/method_routing/route_layer.md
@@ -0,0 +1,30 @@
+Apply a [`tower::Layer`] to the router that will only run if the request matches
+a route.
+
+This works similarly to [`MethodRouter::layer`] except the middleware will only run if
+the request matches a route. This is useful for middleware that return early
+(such as authorization) which might otherwise convert a `405 Method Not Allowed` into a
+`401 Unauthorized`.
+
+# Example
+
+```rust
+use axum::{
+ routing::get,
+ Router,
+};
+use tower_http::auth::RequireAuthorizationLayer;
+
+let app = Router::new().route(
+ "/foo",
+ get(|| async {})
+ .route_layer(RequireAuthorizationLayer::bearer("password"))
+);
+
+// `GET /foo` with a valid token will receive `200 OK`
+// `GET /foo` with a invalid token will receive `401 Unauthorized`
+// `POST /FOO` with a invalid token will receive `405 Method Not Allowed`
+# async {
+# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+# };
+```
diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md
index 468968dd..be913968 100644
--- a/axum/src/docs/routing/nest.md
+++ b/axum/src/docs/routing/nest.md
@@ -72,15 +72,14 @@ let app = Router::new().nest("/:version/api", users_api);
```rust
use axum::{
Router,
- routing::service_method_routing::get,
- error_handling::HandleErrorExt,
+ routing::get_service,
http::StatusCode,
};
use std::{io, convert::Infallible};
use tower_http::services::ServeDir;
// Serves files inside the `public` directory at `GET /public/*`
-let serve_dir_service = ServeDir::new("public")
+let serve_dir_service = get_service(ServeDir::new("public"))
.handle_error(|error: io::Error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
@@ -88,7 +87,7 @@ let serve_dir_service = ServeDir::new("public")
)
});
-let app = Router::new().nest("/public", get(serve_dir_service));
+let app = Router::new().nest("/public", serve_dir_service);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md
index 7d7cb245..4b4ad4d1 100644
--- a/axum/src/docs/routing/route.md
+++ b/axum/src/docs/routing/route.md
@@ -111,8 +111,7 @@ axum also supports routing to general [`Service`]s:
use axum::{
Router,
body::Body,
- routing::service_method_routing as service,
- error_handling::HandleErrorExt,
+ routing::{any_service, get_service},
http::{Request, StatusCode},
};
use tower_http::services::ServeFile;
@@ -125,9 +124,9 @@ let app = Router::new()
// Any request to `/` goes to a service
"/",
// Services whose response body is not `axum::body::BoxBody`
- // can be wrapped in `axum::service::any` (or one of the other routing filters)
+ // can be wrapped in `axum::routing::any_service` (or one of the other routing filters)
// to have the response body mapped
- service::any(service_fn(|_: Request
| async {
+ any_service(service_fn(|_: Request| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
@@ -146,7 +145,7 @@ let app = Router::new()
.route(
// GET `/static/Cargo.toml` goes to a service from tower-http
"/static/Cargo.toml",
- service::get(ServeFile::new("Cargo.toml"))
+ get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| {
(
@@ -161,8 +160,10 @@ let app = Router::new()
```
Routing to arbitrary services in this way has complications for backpressure
-([`Service::poll_ready`]). See the [`service_method_routing`] module for more
-details.
+([`Service::poll_ready`]). See the [Routing to services and backpressure] module
+for more details.
+
+[Routing to services and backpressure]: /#routing-to-services-and-backpressure
# Panics
diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs
index 8eb20824..b1559549 100644
--- a/axum/src/error_handling/mod.rs
+++ b/axum/src/error_handling/mod.rs
@@ -134,19 +134,6 @@ where
}
}
-/// Extension trait to [`Service`] for handling errors by mapping them to
-/// responses.
-///
-/// See [module docs](self) for more details on axum's error handling model.
-pub trait HandleErrorExt: Service> + Sized {
- /// Apply a [`HandleError`] middleware.
- fn handle_error(self, f: F) -> HandleError {
- HandleError::new(self, f)
- }
-}
-
-impl HandleErrorExt for S where S: Service> {}
-
pub mod future {
//! Future types.
diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs
index f0799f07..68e80045 100644
--- a/axum/src/handler/mod.rs
+++ b/axum/src/handler/mod.rs
@@ -404,11 +404,4 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "you said: hi there!");
}
-
- #[test]
- fn traits() {
- use crate::{routing::MethodRouter, test_helpers::*};
- assert_send::>();
- assert_sync::>();
- }
}
diff --git a/axum/src/lib.rs b/axum/src/lib.rs
index 49f33497..d2060072 100644
--- a/axum/src/lib.rs
+++ b/axum/src/lib.rs
@@ -11,6 +11,7 @@
//! - [Responses](#responses)
//! - [Error handling](#error-handling)
//! - [Middleware](#middleware)
+//! - [Routing to services and backpressure](#routing-to-services-and-backpressure)
//! - [Sharing state with handlers](#sharing-state-with-handlers)
//! - [Required dependencies](#required-dependencies)
//! - [Examples](#examples)
@@ -160,6 +161,68 @@
//!
#![doc = include_str!("docs/middleware.md")]
//!
+//! # Routing to services and backpressure
+//!
+//! Generally routing to one of multiple services and backpressure doesn't mix
+//! well. Ideally you would want ensure a service is ready to receive a request
+//! before calling it. However, in order to know which service to call, you need
+//! the request...
+//!
+//! One approach is to not consider the router service itself ready until all
+//! destination services are ready. That is the approach used by
+//! [`tower::steer::Steer`].
+//!
+//! Another approach is to always consider all services ready (always return
+//! `Poll::Ready(Ok(()))`) from `Service::poll_ready` and then actually drive
+//! readiness inside the response future returned by `Service::call`. This works
+//! well when your services don't care about backpressure and are always ready
+//! anyway.
+//!
+//! axum expects that all services used in your app wont care about
+//! backpressure and so it uses the latter strategy. However that means you
+//! should avoid routing to a service (or using a middleware) that _does_ care
+//! about backpressure. At the very least you should [load shed] so requests are
+//! dropped quickly and don't keep piling up.
+//!
+//! It also means that if `poll_ready` returns an error then that error will be
+//! returned in the response future from `call` and _not_ from `poll_ready`. In
+//! that case, the underlying service will _not_ be discarded and will continue
+//! to be used for future requests. Services that expect to be discarded if
+//! `poll_ready` fails should _not_ be used with axum.
+//!
+//! One possible approach is to only apply backpressure sensitive middleware
+//! around your entire app. This is possible because axum applications are
+//! themselves services:
+//!
+//! ```rust
+//! use axum::{
+//! routing::get,
+//! Router,
+//! };
+//! use tower::ServiceBuilder;
+//! # let some_backpressure_sensitive_middleware =
+//! # tower::layer::util::Identity::new();
+//!
+//! async fn handler() { /* ... */ }
+//!
+//! let app = Router::new().route("/", get(handler));
+//!
+//! let app = ServiceBuilder::new()
+//! .layer(some_backpressure_sensitive_middleware)
+//! .service(app);
+//! # async {
+//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+//! # };
+//! ```
+//!
+//! However when applying middleware around your whole application in this way
+//! you have to take care that errors are still being handled with
+//! appropriately.
+//!
+//! Also note that handlers created from async functions don't care about
+//! backpressure and are always ready. So if you're not using any Tower
+//! middleware you don't have to worry about any of this.
+//!
//! # Sharing state with handlers
//!
//! It is common to share some state between handlers for example to share a
@@ -255,8 +318,8 @@
//! [`OriginalUri`]: crate::extract::OriginalUri
//! [`Service`]: tower::Service
//! [`Service::poll_ready`]: tower::Service::poll_ready
+//! [`Service`'s]: tower::Service
//! [`tower::Service`]: tower::Service
-//! [`handle_error`]: error_handling::HandleErrorExt::handle_error
//! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides
//! [`Uuid`]: https://docs.rs/uuid/latest/uuid/
//! [`FromRequest`]: crate::extract::FromRequest
@@ -267,6 +330,7 @@
//! [`debug_handler`]: https://docs.rs/axum-debug/latest/axum_debug/attr.debug_handler.html
//! [`Handler`]: crate::handler::Handler
//! [`Infallible`]: std::convert::Infallible
+//! [load shed]: tower::load_shed
#![warn(
clippy::all,
diff --git a/axum/src/routing/future.rs b/axum/src/routing/future.rs
index e09be0ae..1cd974ff 100644
--- a/axum/src/routing/future.rs
+++ b/axum/src/routing/future.rs
@@ -2,26 +2,22 @@
use crate::body::BoxBody;
use futures_util::future::Either;
-use http::{Request, Response};
+use http::Response;
use std::{convert::Infallible, future::ready};
-use tower::util::Oneshot;
-pub use super::{
- into_make_service::IntoMakeServiceFuture, method_not_allowed::MethodNotAllowedFuture,
- route::RouteFuture,
-};
+pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture};
opaque_future! {
/// Response future for [`Router`](super::Router).
pub type RouterFuture =
futures_util::future::Either<
- Oneshot, Request>,
+ RouteFuture,
std::future::Ready, Infallible>>,
>;
}
impl RouterFuture {
- pub(super) fn from_oneshot(future: Oneshot, Request>) -> Self {
+ pub(super) fn from_future(future: RouteFuture) -> Self {
Self::new(Either::Left(future))
}
diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs
index 251be637..82e823a5 100644
--- a/axum/src/routing/method_filter.rs
+++ b/axum/src/routing/method_filter.rs
@@ -1,5 +1,4 @@
use bitflags::bitflags;
-use http::Method;
bitflags! {
/// A filter that matches one or more HTTP methods.
@@ -22,21 +21,3 @@ bitflags! {
const TRACE = 0b100000000;
}
}
-
-impl MethodFilter {
- #[allow(clippy::match_like_matches_macro)]
- pub(crate) fn matches(self, method: &Method) -> bool {
- let method = match *method {
- Method::DELETE => Self::DELETE,
- Method::GET => Self::GET,
- Method::HEAD => Self::HEAD,
- Method::OPTIONS => Self::OPTIONS,
- Method::PATCH => Self::PATCH,
- Method::POST => Self::POST,
- Method::PUT => Self::PUT,
- Method::TRACE => Self::TRACE,
- _ => return false,
- };
- self.contains(method)
- }
-}
diff --git a/axum/src/routing/method_not_allowed.rs b/axum/src/routing/method_not_allowed.rs
deleted file mode 100644
index 6f812b11..00000000
--- a/axum/src/routing/method_not_allowed.rs
+++ /dev/null
@@ -1,82 +0,0 @@
-use crate::body::BoxBody;
-use http::{Request, Response, StatusCode};
-use std::{
- convert::Infallible,
- fmt,
- future::ready,
- marker::PhantomData,
- task::{Context, Poll},
-};
-use tower_service::Service;
-
-/// A [`Service`] that responds with `405 Method not allowed` to all requests.
-///
-/// This is used as the bottom service in a method router. You shouldn't have to
-/// use it manually.
-pub struct MethodNotAllowed {
- _marker: PhantomData E>,
-}
-
-impl MethodNotAllowed {
- pub(crate) fn new() -> Self {
- Self {
- _marker: PhantomData,
- }
- }
-}
-
-impl Clone for MethodNotAllowed {
- fn clone(&self) -> Self {
- Self {
- _marker: PhantomData,
- }
- }
-}
-
-impl fmt::Debug for MethodNotAllowed {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_tuple("MethodNotAllowed").finish()
- }
-}
-
-impl Service> for MethodNotAllowed
-where
- B: Send + 'static,
-{
- type Response = Response;
- type Error = E;
- type Future = MethodNotAllowedFuture;
-
- #[inline]
- fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
- Poll::Ready(Ok(()))
- }
-
- fn call(&mut self, _req: Request) -> Self::Future {
- let res = Response::builder()
- .status(StatusCode::METHOD_NOT_ALLOWED)
- .body(crate::body::empty())
- .unwrap();
-
- MethodNotAllowedFuture::new(ready(Ok(res)))
- }
-}
-
-opaque_future! {
- /// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
- pub type MethodNotAllowedFuture =
- std::future::Ready, E>>;
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn traits() {
- use crate::test_helpers::*;
-
- assert_send::>();
- assert_sync::>();
- }
-}
diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs
new file mode 100644
index 00000000..271dfa45
--- /dev/null
+++ b/axum/src/routing/method_routing.rs
@@ -0,0 +1,1085 @@
+use crate::{
+ body::{box_body, Body, BoxBody, Bytes},
+ error_handling::HandleErrorLayer,
+ handler::Handler,
+ http::{Method, Request, Response, StatusCode},
+ routing::{Fallback, MethodFilter, Route},
+ BoxError,
+};
+use http_body::Empty;
+use std::{
+ convert::Infallible,
+ fmt,
+ marker::PhantomData,
+ task::{Context, Poll},
+};
+use tower::{service_fn, ServiceBuilder, ServiceExt};
+use tower_http::map_response_body::MapResponseBodyLayer;
+use tower_layer::Layer;
+use tower_service::Service;
+
+macro_rules! top_level_service_fn {
+ (
+ $name:ident, GET
+ ) => {
+ top_level_service_fn!(
+ /// Route `GET` requests to the given service.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{
+ /// http::Request,
+ /// Router,
+ /// routing::get_service,
+ /// };
+ /// use http::Response;
+ /// use std::convert::Infallible;
+ /// use hyper::Body;
+ ///
+ /// let service = tower::service_fn(|request: Request| async {
+ /// Ok::<_, Infallible>(Response::new(Body::empty()))
+ /// });
+ ///
+ /// // Requests to `GET /` will go to `service`.
+ /// let app = Router::new().route("/", get_service(service));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ ///
+ /// Note that `get` routes will also be called for `HEAD` requests but will have
+ /// the response body removed. Make sure to add explicit `HEAD` routes
+ /// afterwards.
+ $name,
+ GET
+ );
+ };
+
+ (
+ $name:ident, $method:ident
+ ) => {
+ top_level_service_fn!(
+ #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
+ ///
+ /// See [`get_service`] for an example.
+ $name,
+ $method
+ );
+ };
+
+ (
+ $(#[$m:meta])+
+ $name:ident, $method:ident
+ ) => {
+ $(#[$m])+
+ pub fn $name(svc: S) -> MethodRouter
+ where
+ S: Service, Response = Response> + Clone + Send + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+ {
+ on_service(MethodFilter::$method, svc)
+ }
+ };
+}
+
+macro_rules! top_level_handler_fn {
+ (
+ $name:ident, GET
+ ) => {
+ top_level_handler_fn!(
+ /// Route `GET` requests to the given handler.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{
+ /// routing::get,
+ /// Router,
+ /// };
+ ///
+ /// async fn handler() {}
+ ///
+ /// // Requests to `GET /` will go to `handler`.
+ /// let app = Router::new().route("/", get(handler));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ ///
+ /// Note that `get` routes will also be called for `HEAD` requests but will have
+ /// the response body removed. Make sure to add explicit `HEAD` routes
+ /// afterwards.
+ $name,
+ GET
+ );
+ };
+
+ (
+ $name:ident, $method:ident
+ ) => {
+ top_level_handler_fn!(
+ #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
+ ///
+ /// See [`get`] for an example.
+ $name,
+ $method
+ );
+ };
+
+ (
+ $(#[$m:meta])+
+ $name:ident, $method:ident
+ ) => {
+ $(#[$m])+
+ pub fn $name(handler: H) -> MethodRouter
+ where
+ H: Handler,
+ B: Send + 'static,
+ T: 'static,
+ {
+ on(MethodFilter::$method, handler)
+ }
+ };
+}
+
+macro_rules! chained_service_fn {
+ (
+ $name:ident, GET
+ ) => {
+ chained_service_fn!(
+ /// Chain an additional service that will only accept `GET` requests.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{
+ /// http::Request,
+ /// Router,
+ /// routing::post_service,
+ /// };
+ /// use http::Response;
+ /// use std::convert::Infallible;
+ /// use hyper::Body;
+ ///
+ /// let service = tower::service_fn(|request: Request| async {
+ /// Ok::<_, Infallible>(Response::new(Body::empty()))
+ /// });
+ ///
+ /// let other_service = tower::service_fn(|request: Request| async {
+ /// Ok::<_, Infallible>(Response::new(Body::empty()))
+ /// });
+ ///
+ /// // Requests to `GET /` will go to `service` and `POST /` will go to
+ /// // `other_service`.
+ /// let app = Router::new().route("/", post_service(service).get_service(other_service));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ ///
+ /// Note that `get` routes will also be called for `HEAD` requests but will have
+ /// the response body removed. Make sure to add explicit `HEAD` routes
+ /// afterwards.
+ $name,
+ GET
+ );
+ };
+
+ (
+ $name:ident, $method:ident
+ ) => {
+ chained_service_fn!(
+ #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
+ ///
+ /// See [`MethodRouter::get_service`] for an example.
+ $name,
+ $method
+ );
+ };
+
+ (
+ $(#[$m:meta])+
+ $name:ident, $method:ident
+ ) => {
+ $(#[$m])+
+ pub fn $name(self, svc: S) -> Self
+ where
+ S: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+ {
+ self.on_service(MethodFilter::$method, svc)
+ }
+ };
+}
+
+macro_rules! chained_handler_fn {
+ (
+ $name:ident, GET
+ ) => {
+ chained_handler_fn!(
+ /// Chain an additional handler that will only accept `GET` requests.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{routing::post, Router};
+ ///
+ /// async fn handler() {}
+ ///
+ /// async fn other_handler() {}
+ ///
+ /// // Requests to `GET /` will go to `handler` and `POST /` will go to
+ /// // `other_handler`.
+ /// let app = Router::new().route("/", post(handler).get(other_handler));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ ///
+ /// Note that `get` routes will also be called for `HEAD` requests but will have
+ /// the response body removed. Make sure to add explicit `HEAD` routes
+ /// afterwards.
+ $name,
+ GET
+ );
+ };
+
+ (
+ $name:ident, $method:ident
+ ) => {
+ chained_handler_fn!(
+ #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
+ ///
+ /// See [`MethodRouter::get`] for an example.
+ $name,
+ $method
+ );
+ };
+
+ (
+ $(#[$m:meta])+
+ $name:ident, $method:ident
+ ) => {
+ $(#[$m])+
+ pub fn $name(self, handler: H) -> Self
+ where
+ H: Handler,
+ T: 'static,
+ {
+ self.on(MethodFilter::$method, handler)
+ }
+ };
+}
+
+top_level_service_fn!(delete_service, DELETE);
+top_level_service_fn!(get_service, GET);
+top_level_service_fn!(head_service, HEAD);
+top_level_service_fn!(options_service, OPTIONS);
+top_level_service_fn!(patch_service, PATCH);
+top_level_service_fn!(post_service, POST);
+top_level_service_fn!(put_service, PUT);
+top_level_service_fn!(trace_service, TRACE);
+
+/// Route requests with the given method to the service.
+///
+/// # Example
+///
+/// ```rust
+/// use axum::{
+/// http::Request,
+/// routing::on,
+/// Router,
+/// routing::{MethodFilter, on_service},
+/// };
+/// use http::Response;
+/// use std::convert::Infallible;
+/// use hyper::Body;
+///
+/// let service = tower::service_fn(|request: Request| async {
+/// Ok::<_, Infallible>(Response::new(Body::empty()))
+/// });
+///
+/// // Requests to `POST /` will go to `service`.
+/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+pub fn on_service(
+ filter: MethodFilter,
+ svc: S,
+) -> MethodRouter
+where
+ S: Service, Response = Response> + Clone + Send + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+{
+ MethodRouter::new().on_service(filter, svc)
+}
+
+/// Route requests to the given service regardless of its method.
+///
+/// # Example
+///
+/// ```rust
+/// use axum::{
+/// http::Request,
+/// Router,
+/// routing::any_service,
+/// };
+/// use http::Response;
+/// use std::convert::Infallible;
+/// use hyper::Body;
+///
+/// let service = tower::service_fn(|request: Request| async {
+/// Ok::<_, Infallible>(Response::new(Body::empty()))
+/// });
+///
+/// // All requests to `/` will go to `service`.
+/// let app = Router::new().route("/", any_service(service));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+///
+/// Additional methods can still be chained:
+///
+/// ```rust
+/// use axum::{
+/// http::Request,
+/// Router,
+/// routing::any_service,
+/// };
+/// use http::Response;
+/// use std::convert::Infallible;
+/// use hyper::Body;
+///
+/// let service = tower::service_fn(|request: Request| async {
+/// # Ok::<_, Infallible>(Response::new(Body::empty()))
+/// // ...
+/// });
+///
+/// let other_service = tower::service_fn(|request: Request| async {
+/// # Ok::<_, Infallible>(Response::new(Body::empty()))
+/// // ...
+/// });
+///
+/// // `POST /` goes to `other_service`. All other requests go to `service`
+/// let app = Router::new().route("/", any_service(service).post_service(other_service));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+pub fn any_service(svc: S) -> MethodRouter
+where
+ S: Service, Response = Response> + Clone + Send + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+{
+ MethodRouter::new().fallback(svc)
+}
+
+top_level_handler_fn!(delete, DELETE);
+top_level_handler_fn!(get, GET);
+top_level_handler_fn!(head, HEAD);
+top_level_handler_fn!(options, OPTIONS);
+top_level_handler_fn!(patch, PATCH);
+top_level_handler_fn!(post, POST);
+top_level_handler_fn!(put, PUT);
+top_level_handler_fn!(trace, TRACE);
+
+/// Route requests with the given method to the handler.
+///
+/// # Example
+///
+/// ```rust
+/// use axum::{
+/// routing::on,
+/// Router,
+/// routing::MethodFilter,
+/// };
+///
+/// async fn handler() {}
+///
+/// // Requests to `POST /` will go to `handler`.
+/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+pub fn on(filter: MethodFilter, handler: H) -> MethodRouter
+where
+ H: Handler,
+ B: Send + 'static,
+ T: 'static,
+{
+ MethodRouter::new().on(filter, handler)
+}
+
+/// Route requests with the given handler regardless of the method.
+///
+/// # Example
+///
+/// ```rust
+/// use axum::{
+/// routing::any,
+/// Router,
+/// };
+///
+/// async fn handler() {}
+///
+/// // All requests to `/` will go to `handler`.
+/// let app = Router::new().route("/", any(handler));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+///
+/// Additional methods can still be chained:
+///
+/// ```rust
+/// use axum::{
+/// routing::any,
+/// Router,
+/// };
+///
+/// async fn handler() {}
+///
+/// async fn other_handler() {}
+///
+/// // `POST /` goes to `other_handler`. All other requests go to `handler`
+/// let app = Router::new().route("/", any(handler).post(other_handler));
+/// # async {
+/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+/// # };
+/// ```
+pub fn any(handler: H) -> MethodRouter
+where
+ H: Handler,
+ B: Send + 'static,
+ T: 'static,
+{
+ MethodRouter::new().fallback_boxed_response_body(handler.into_service())
+}
+
+/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
+/// allows chaining additional handlers and services.
+pub struct MethodRouter {
+ get: Option>,
+ head: Option>,
+ delete: Option>,
+ options: Option>,
+ patch: Option>,
+ post: Option>,
+ put: Option>,
+ trace: Option>,
+ fallback: Fallback,
+ _request_body: PhantomData (B, E)>,
+}
+
+impl fmt::Debug for MethodRouter {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("MethodRouter")
+ .field("get", &self.get)
+ .field("head", &self.head)
+ .field("delete", &self.delete)
+ .field("options", &self.options)
+ .field("patch", &self.patch)
+ .field("post", &self.post)
+ .field("put", &self.put)
+ .field("trace", &self.trace)
+ .field("fallback", &self.fallback)
+ .finish()
+ }
+}
+
+impl MethodRouter {
+ /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
+ /// requests.
+ pub fn new() -> Self {
+ let fallback = Route::new(service_fn(|_: Request| async {
+ let mut response = Response::new(box_body(Empty::new()));
+ *response.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
+ Ok(response)
+ }));
+
+ Self {
+ get: None,
+ head: None,
+ delete: None,
+ options: None,
+ patch: None,
+ post: None,
+ put: None,
+ trace: None,
+ fallback: Fallback::Default(fallback),
+ _request_body: PhantomData,
+ }
+ }
+}
+
+impl MethodRouter
+where
+ B: Send + 'static,
+{
+ /// Chain an additional handler that will accept requests matching the given
+ /// `MethodFilter`.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{
+ /// routing::get,
+ /// Router,
+ /// routing::MethodFilter
+ /// };
+ ///
+ /// async fn handler() {}
+ ///
+ /// async fn other_handler() {}
+ ///
+ /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
+ /// // `other_handler`
+ /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ pub fn on(self, filter: MethodFilter, handler: H) -> Self
+ where
+ H: Handler,
+ T: 'static,
+ {
+ self.on_service_boxed_response_body(filter, handler.into_service())
+ }
+
+ chained_handler_fn!(delete, DELETE);
+ chained_handler_fn!(get, GET);
+ chained_handler_fn!(head, HEAD);
+ chained_handler_fn!(options, OPTIONS);
+ chained_handler_fn!(patch, PATCH);
+ chained_handler_fn!(post, POST);
+ chained_handler_fn!(put, PUT);
+ chained_handler_fn!(trace, TRACE);
+}
+
+impl MethodRouter {
+ /// Chain an additional service that will accept requests matching the given
+ /// `MethodFilter`.
+ ///
+ /// # Example
+ ///
+ /// ```rust
+ /// use axum::{
+ /// http::Request,
+ /// Router,
+ /// routing::{MethodFilter, on_service},
+ /// };
+ /// use http::Response;
+ /// use std::convert::Infallible;
+ /// use hyper::Body;
+ ///
+ /// let service = tower::service_fn(|request: Request| async {
+ /// Ok::<_, Infallible>(Response::new(Body::empty()))
+ /// });
+ ///
+ /// // Requests to `DELETE /` will go to `service`
+ /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
+ /// # async {
+ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
+ /// # };
+ /// ```
+ pub fn on_service(self, filter: MethodFilter, svc: S) -> Self
+ where
+ S: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+ {
+ self.on_service_boxed_response_body(filter, svc.map_response(|res| res.map(box_body)))
+ }
+
+ chained_service_fn!(delete_service, DELETE);
+ chained_service_fn!(get_service, GET);
+ chained_service_fn!(head_service, HEAD);
+ chained_service_fn!(options_service, OPTIONS);
+ chained_service_fn!(patch_service, PATCH);
+ chained_service_fn!(post_service, POST);
+ chained_service_fn!(put_service, PUT);
+ chained_service_fn!(trace_service, TRACE);
+
+ #[doc = include_str!("../docs/method_routing/fallback.md")]
+ pub fn fallback(mut self, svc: S) -> Self
+ where
+ S: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ S::Future: Send + 'static,
+ ResBody: http_body::Body + Send + 'static,
+ ResBody::Error: Into,
+ {
+ self.fallback = Fallback::Custom(Route::new(svc.map_response(|res| res.map(box_body))));
+ self
+ }
+
+ fn fallback_boxed_response_body(mut self, svc: S) -> Self
+ where
+ S: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ S::Future: Send + 'static,
+ {
+ self.fallback = Fallback::Custom(Route::new(svc));
+ self
+ }
+
+ #[doc = include_str!("../docs/method_routing/layer.md")]
+ pub fn layer(
+ self,
+ layer: L,
+ ) -> MethodRouter
+ where
+ L: Layer>,
+ L::Service: Service, Response = Response, Error = NewError>
+ + Clone
+ + Send
+ + 'static,
+ >>::Future: Send + 'static,
+ NewResBody: http_body::Body + Send + 'static,
+ NewResBody::Error: Into,
+ {
+ let layer = ServiceBuilder::new()
+ .layer_fn(Route::new)
+ .layer(MapResponseBodyLayer::new(box_body))
+ .layer(layer)
+ .into_inner();
+ let layer_fn = |s| layer.layer(s);
+
+ MethodRouter {
+ get: self.get.map(layer_fn),
+ head: self.head.map(layer_fn),
+ delete: self.delete.map(layer_fn),
+ options: self.options.map(layer_fn),
+ patch: self.patch.map(layer_fn),
+ post: self.post.map(layer_fn),
+ put: self.put.map(layer_fn),
+ trace: self.trace.map(layer_fn),
+ fallback: self.fallback.map(layer_fn),
+ _request_body: PhantomData,
+ }
+ }
+
+ #[doc = include_str!("../docs/method_routing/route_layer.md")]
+ pub fn route_layer(self, layer: L) -> MethodRouter
+ where
+ L: Layer>,
+ L::Service: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ >>::Future: Send + 'static,
+ NewResBody: http_body::Body + Send + 'static,
+ NewResBody::Error: Into,
+ {
+ let layer = ServiceBuilder::new()
+ .layer_fn(Route::new)
+ .layer(MapResponseBodyLayer::new(box_body))
+ .layer(layer)
+ .into_inner();
+ let layer_fn = |s| layer.layer(s);
+
+ MethodRouter {
+ get: self.get.map(layer_fn),
+ head: self.head.map(layer_fn),
+ delete: self.delete.map(layer_fn),
+ options: self.options.map(layer_fn),
+ patch: self.patch.map(layer_fn),
+ post: self.post.map(layer_fn),
+ put: self.put.map(layer_fn),
+ trace: self.trace.map(layer_fn),
+ fallback: self.fallback,
+ _request_body: PhantomData,
+ }
+ }
+
+ #[doc = include_str!("../docs/method_routing/merge.md")]
+ pub fn merge(self, other: MethodRouter) -> Self {
+ macro_rules! merge {
+ ( $first:ident, $second:ident ) => {
+ match ($first, $second) {
+ (Some(_), Some(_)) => panic!(concat!(
+ "Overlapping method route. Cannot merge two method routes that both define `",
+ stringify!($first),
+ "`"
+ )),
+ (Some(svc), None) => Some(svc),
+ (None, Some(svc)) => Some(svc),
+ (None, None) => None,
+ }
+ };
+ }
+
+ let Self {
+ get,
+ head,
+ delete,
+ options,
+ patch,
+ post,
+ put,
+ trace,
+ fallback,
+ _request_body: _,
+ } = self;
+
+ let Self {
+ get: get_other,
+ head: head_other,
+ delete: delete_other,
+ options: options_other,
+ patch: patch_other,
+ post: post_other,
+ put: put_other,
+ trace: trace_other,
+ fallback: fallback_other,
+ _request_body: _,
+ } = other;
+
+ let get = merge!(get, get_other);
+ let head = merge!(head, head_other);
+ let delete = merge!(delete, delete_other);
+ let options = merge!(options, options_other);
+ let patch = merge!(patch, patch_other);
+ let post = merge!(post, post_other);
+ let put = merge!(put, put_other);
+ let trace = merge!(trace, trace_other);
+
+ let fallback = match (fallback, fallback_other) {
+ (pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
+ (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
+ (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
+ (Fallback::Custom(_), Fallback::Custom(_)) => {
+ panic!("Cannot merge two `MethodRouter`s that both have a fallback")
+ }
+ };
+
+ Self {
+ get,
+ head,
+ delete,
+ options,
+ patch,
+ post,
+ put,
+ trace,
+ fallback,
+ _request_body: PhantomData,
+ }
+ }
+
+ /// Apply a [`HandleErrorLayer`].
+ ///
+ /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
+ pub fn handle_error(self, f: F) -> MethodRouter
+ where
+ F: FnOnce(E) -> Res + Clone + Send + 'static,
+ Res: crate::response::IntoResponse,
+ ReqBody: Send + 'static,
+ E: 'static,
+ {
+ self.layer(HandleErrorLayer::new(f))
+ }
+
+ fn on_service_boxed_response_body(self, filter: MethodFilter, svc: S) -> Self
+ where
+ S: Service, Response = Response, Error = E>
+ + Clone
+ + Send
+ + 'static,
+ S::Future: Send + 'static,
+ {
+ // written with a pattern match like this to ensure we update all fields
+ let Self {
+ mut get,
+ mut head,
+ mut delete,
+ mut options,
+ mut patch,
+ mut post,
+ mut put,
+ mut trace,
+ fallback,
+ _request_body: _,
+ } = self;
+ let svc = Some(Route::new(svc));
+ if filter.contains(MethodFilter::GET) {
+ get = svc.clone();
+ }
+ if filter.contains(MethodFilter::HEAD) {
+ head = svc.clone();
+ }
+ if filter.contains(MethodFilter::DELETE) {
+ delete = svc.clone();
+ }
+ if filter.contains(MethodFilter::OPTIONS) {
+ options = svc.clone();
+ }
+ if filter.contains(MethodFilter::PATCH) {
+ patch = svc.clone();
+ }
+ if filter.contains(MethodFilter::POST) {
+ post = svc.clone();
+ }
+ if filter.contains(MethodFilter::PUT) {
+ put = svc.clone();
+ }
+ if filter.contains(MethodFilter::TRACE) {
+ trace = svc;
+ }
+ Self {
+ get,
+ head,
+ delete,
+ options,
+ patch,
+ post,
+ put,
+ trace,
+ fallback,
+ _request_body: PhantomData,
+ }
+ }
+}
+
+impl Clone for MethodRouter {
+ fn clone(&self) -> Self {
+ Self {
+ get: self.get.clone(),
+ head: self.head.clone(),
+ delete: self.delete.clone(),
+ options: self.options.clone(),
+ patch: self.patch.clone(),
+ post: self.post.clone(),
+ put: self.put.clone(),
+ trace: self.trace.clone(),
+ fallback: self.fallback.clone(),
+ _request_body: PhantomData,
+ }
+ }
+}
+
+impl Default for MethodRouter
+where
+ B: Send + 'static,
+{
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+use crate::routing::future::RouteFuture;
+
+impl Service> for MethodRouter {
+ type Response = Response;
+ type Error = E;
+ type Future = RouteFuture;
+
+ #[inline]
+ fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn call(&mut self, req: Request) -> Self::Future {
+ macro_rules! call {
+ (
+ $req:expr,
+ $method:expr,
+ $method_variant:ident,
+ $svc:expr
+ ) => {
+ if $method == Method::$method_variant {
+ if let Some(svc) = $svc {
+ return RouteFuture::new(svc.0.clone().oneshot($req))
+ .strip_body($method == Method::HEAD);
+ }
+ }
+ };
+ }
+
+ let method = req.method().clone();
+
+ // written with a pattern match like this to ensure we call all routes
+ let Self {
+ get,
+ head,
+ delete,
+ options,
+ patch,
+ post,
+ put,
+ trace,
+ fallback,
+ _request_body: _,
+ } = self;
+
+ call!(req, method, HEAD, head);
+ call!(req, method, HEAD, get);
+ call!(req, method, GET, get);
+ call!(req, method, POST, post);
+ call!(req, method, OPTIONS, options);
+ call!(req, method, PATCH, patch);
+ call!(req, method, PUT, put);
+ call!(req, method, DELETE, delete);
+ call!(req, method, TRACE, trace);
+
+ match fallback {
+ Fallback::Default(fallback) => {
+ RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
+ }
+ Fallback::Custom(fallback) => {
+ RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{body::Body, error_handling::HandleErrorLayer};
+ use std::time::Duration;
+ use tower::{timeout::TimeoutLayer, Service, ServiceExt};
+ use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
+
+ #[tokio::test]
+ async fn method_not_allowed_by_default() {
+ let mut svc = MethodRouter::new();
+ let (status, body) = call(Method::GET, &mut svc).await;
+ assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
+ assert!(body.is_empty());
+ }
+
+ #[tokio::test]
+ async fn get_handler() {
+ let mut svc = MethodRouter::new().get(ok);
+ let (status, body) = call(Method::GET, &mut svc).await;
+ assert_eq!(status, StatusCode::OK);
+ assert_eq!(body, "ok");
+ }
+
+ #[tokio::test]
+ async fn get_accepts_head() {
+ let mut svc = MethodRouter::new().get(ok);
+ let (status, body) = call(Method::HEAD, &mut svc).await;
+ assert_eq!(status, StatusCode::OK);
+ assert!(body.is_empty());
+ }
+
+ #[tokio::test]
+ async fn head_takes_precedence_over_get() {
+ let mut svc = MethodRouter::new().head(created).get(ok);
+ let (status, body) = call(Method::HEAD, &mut svc).await;
+ assert_eq!(status, StatusCode::CREATED);
+ assert!(body.is_empty());
+ }
+
+ #[tokio::test]
+ async fn merge() {
+ let mut svc = get(ok).merge(post(ok));
+
+ let (status, _) = call(Method::GET, &mut svc).await;
+ assert_eq!(status, StatusCode::OK);
+
+ let (status, _) = call(Method::POST, &mut svc).await;
+ assert_eq!(status, StatusCode::OK);
+ }
+
+ #[tokio::test]
+ async fn layer() {
+ let mut svc = MethodRouter::new()
+ .get(|| async { std::future::pending::<()>().await })
+ .layer(RequireAuthorizationLayer::bearer("password"));
+
+ // method with route
+ let (status, _) = call(Method::GET, &mut svc).await;
+ assert_eq!(status, StatusCode::UNAUTHORIZED);
+
+ // method without route
+ let (status, _) = call(Method::DELETE, &mut svc).await;
+ assert_eq!(status, StatusCode::UNAUTHORIZED);
+ }
+
+ #[tokio::test]
+ async fn route_layer() {
+ let mut svc = MethodRouter::new()
+ .get(|| async { std::future::pending::<()>().await })
+ .route_layer(RequireAuthorizationLayer::bearer("password"));
+
+ // method with route
+ let (status, _) = call(Method::GET, &mut svc).await;
+ assert_eq!(status, StatusCode::UNAUTHORIZED);
+
+ // method without route
+ let (status, _) = call(Method::DELETE, &mut svc).await;
+ assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
+ }
+
+ #[allow(dead_code)]
+ fn buiding_complex_router() {
+ let app = crate::Router::new().route(
+ "/",
+ // use the all the things :bomb:
+ get(ok)
+ .post(ok)
+ .route_layer(RequireAuthorizationLayer::bearer("password"))
+ .merge(delete_service(ServeDir::new(".")).handle_error(|_| StatusCode::NOT_FOUND))
+ .fallback((|| async { StatusCode::NOT_FOUND }).into_service())
+ .put(ok)
+ .layer(
+ ServiceBuilder::new()
+ .layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT))
+ .layer(TimeoutLayer::new(Duration::from_secs(10))),
+ ),
+ );
+
+ crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
+ }
+
+ async fn call(method: Method, svc: &mut S) -> (StatusCode, String)
+ where
+ S: Service, Response = Response, Error = Infallible>,
+ {
+ let request = Request::builder()
+ .uri("/")
+ .method(method)
+ .body(Body::empty())
+ .unwrap();
+ let response = svc.ready().await.unwrap().call(request).await.unwrap();
+ let (parts, body) = response.into_parts();
+ let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
+ (parts.status, body)
+ }
+
+ async fn ok() -> (StatusCode, &'static str) {
+ (StatusCode::OK, "ok")
+ }
+
+ async fn created() -> (StatusCode, &'static str) {
+ (StatusCode::CREATED, "created")
+ }
+}
diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs
index 54e70952..57689760 100644
--- a/axum/src/routing/mod.rs
+++ b/axum/src/routing/mod.rs
@@ -7,6 +7,7 @@ use crate::{
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
MatchedPath, OriginalUri,
},
+ routing::strip_prefix::StripPrefix,
util::{ByteStr, PercentDecodedByteStr},
BoxError,
};
@@ -20,18 +21,16 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
-use tower::{util::ServiceExt, ServiceBuilder};
+use tower::{layer::layer_fn, ServiceBuilder};
use tower_http::map_response_body::MapResponseBodyLayer;
use tower_layer::Layer;
use tower_service::Service;
pub mod future;
-pub mod handler_method_routing;
-pub mod service_method_routing;
mod into_make_service;
mod method_filter;
-mod method_not_allowed;
+mod method_routing;
mod not_found;
mod route;
mod strip_prefix;
@@ -39,14 +38,12 @@ mod strip_prefix;
#[cfg(test)]
mod tests;
-pub use self::{
- into_make_service::IntoMakeService, method_filter::MethodFilter,
- method_not_allowed::MethodNotAllowed, route::Route,
-};
+pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
-#[doc(no_inline)]
-pub use self::handler_method_routing::{
- any, delete, get, head, on, options, patch, post, put, trace, MethodRouter,
+pub use self::method_routing::{
+ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
+ options, options_service, patch, patch_service, post, post_service, put, put_service, trace,
+ trace_service, MethodRouter,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -63,7 +60,7 @@ impl RouteId {
/// The router type for composing handlers and services.
#[derive(Debug)]
pub struct Router {
- routes: HashMap>,
+ routes: HashMap>,
node: Node,
fallback: Fallback,
nested_at_root: bool,
@@ -131,11 +128,32 @@ where
let id = RouteId::next();
+ let service = match try_downcast::, _>(service) {
+ Ok(method_router) => {
+ if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
+ .node
+ .path_to_route_id
+ .get(path)
+ .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
+ {
+ // if we're adding a new `MethodRouter` to a route that already has one just
+ // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
+ let service =
+ Endpoint::MethodRouter(prev_method_router.clone().merge(method_router));
+ self.routes.insert(route_id, service);
+ return self;
+ } else {
+ Endpoint::MethodRouter(method_router)
+ }
+ }
+ Err(service) => Endpoint::Route(Route::new(service)),
+ };
+
if let Err(err) = self.node.insert(path, id) {
self.panic_on_matchit_error(err);
}
- self.routes.insert(id, Route::new(service));
+ self.routes.insert(id, service);
self
}
@@ -179,14 +197,22 @@ where
nested_at_root: _,
} = router;
- for (id, nested_path) in node.paths {
+ for (id, nested_path) in node.route_id_to_path {
let route = routes.remove(&id).unwrap();
let full_path = if &*nested_path == "/" {
path.to_string()
} else {
format!("{}{}", path, nested_path)
};
- self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix));
+ self = match route {
+ Endpoint::MethodRouter(method_router) => self.route(
+ &full_path,
+ method_router.layer(layer_fn(|s| StripPrefix::new(s, prefix))),
+ ),
+ Endpoint::Route(route) => {
+ self.route(&full_path, StripPrefix::new(route, prefix))
+ }
+ };
}
debug_assert!(routes.is_empty());
@@ -248,20 +274,25 @@ where
NewResBody::Error: Into,
{
let layer = ServiceBuilder::new()
- .layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(box_body))
- .layer(layer);
+ .layer(layer)
+ .into_inner();
let routes = self
.routes
.into_iter()
.map(|(id, route)| {
- let route = Layer::layer(&layer, route);
+ let route = match route {
+ Endpoint::MethodRouter(method_router) => {
+ Endpoint::MethodRouter(method_router.layer(&layer))
+ }
+ Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
+ };
(id, route)
})
.collect();
- let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc));
+ let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc)));
Router {
routes,
@@ -284,15 +315,20 @@ where
NewResBody::Error: Into,
{
let layer = ServiceBuilder::new()
- .layer_fn(Route::new)
.layer(MapResponseBodyLayer::new(box_body))
- .layer(layer);
+ .layer(layer)
+ .into_inner();
let routes = self
.routes
.into_iter()
.map(|(id, route)| {
- let route = Layer::layer(&layer, route);
+ let route = match route {
+ Endpoint::MethodRouter(method_router) => {
+ Endpoint::MethodRouter(method_router.layer(&layer))
+ }
+ Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
+ };
(id, route)
})
.collect();
@@ -360,7 +396,7 @@ where
let id = *match_.value;
req.extensions_mut().insert(id);
- if let Some(matched_path) = self.node.paths.get(&id) {
+ if let Some(matched_path) = self.node.route_id_to_path.get(&id) {
let matched_path = if let Some(previous) = req.extensions_mut().get::() {
// a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) =
@@ -388,13 +424,17 @@ where
insert_url_params(&mut req, params);
- let route = self
+ let mut route = self
.routes
.get(&id)
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
- RouterFuture::from_oneshot(route.oneshot(req))
+ let future = match &mut route {
+ Endpoint::MethodRouter(inner) => inner.call(req),
+ Endpoint::Route(inner) => inner.call(req),
+ };
+ RouterFuture::from_future(future)
}
fn panic_on_matchit_error(&self, err: matchit::InsertError) {
@@ -449,10 +489,10 @@ where
} else {
match &self.fallback {
Fallback::Default(inner) => {
- RouterFuture::from_oneshot(inner.clone().oneshot(req))
+ RouterFuture::from_future(inner.clone().call(req))
}
Fallback::Custom(inner) => {
- RouterFuture::from_oneshot(inner.clone().oneshot(req))
+ RouterFuture::from_future(inner.clone().call(req))
}
}
}
@@ -537,7 +577,8 @@ pub(crate) struct InvalidUtf8InPathParam {
#[derive(Clone, Default)]
struct Node {
inner: matchit::Node,
- paths: HashMap>,
+ route_id_to_path: HashMap>,
+ path_to_route_id: HashMap, RouteId>,
}
impl Node {
@@ -547,13 +588,18 @@ impl Node {
val: RouteId,
) -> Result<(), matchit::InsertError> {
let path = path.into();
+
self.inner.insert(&path, val)?;
- self.paths.insert(val, path.into());
+
+ let shared_path: Arc = path.into();
+ self.route_id_to_path.insert(val, shared_path.clone());
+ self.path_to_route_id.insert(shared_path, val);
+
Ok(())
}
fn merge(&mut self, other: Node) -> Result<(), matchit::InsertError> {
- for (id, path) in other.paths {
+ for (id, path) in other.route_id_to_path {
self.insert(&*path, id)?;
}
Ok(())
@@ -569,16 +615,18 @@ impl Node {
impl fmt::Debug for Node {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("Node").field("paths", &self.paths).finish()
+ f.debug_struct("Node")
+ .field("paths", &self.route_id_to_path)
+ .finish()
}
}
-enum Fallback {
- Default(Route),
- Custom(Route),
+enum Fallback {
+ Default(Route),
+ Custom(Route),
}
-impl Clone for Fallback {
+impl Clone for Fallback {
fn clone(&self) -> Self {
match self {
Fallback::Default(inner) => Fallback::Default(inner.clone()),
@@ -587,7 +635,7 @@ impl Clone for Fallback {
}
}
-impl fmt::Debug for Fallback {
+impl fmt::Debug for Fallback {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
@@ -596,10 +644,10 @@ impl fmt::Debug for Fallback {
}
}
-impl Fallback {
- fn map(self, f: F) -> Fallback
+impl Fallback {
+ fn map(self, f: F) -> Fallback
where
- F: FnOnce(Route) -> Route,
+ F: FnOnce(Route) -> Route,
{
match self {
Fallback::Default(inner) => Fallback::Default(f(inner)),
@@ -622,6 +670,29 @@ where
}
}
+enum Endpoint {
+ MethodRouter(MethodRouter),
+ Route(Route),
+}
+
+impl Clone for Endpoint {
+ fn clone(&self) -> Self {
+ match self {
+ Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()),
+ Endpoint::Route(inner) => Endpoint::Route(inner.clone()),
+ }
+ }
+}
+
+impl fmt::Debug for Endpoint {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::MethodRouter(inner) => inner.fmt(f),
+ Self::Route(inner) => inner.fmt(f),
+ }
+ }
+}
+
#[test]
fn traits() {
use crate::test_helpers::*;
diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs
index 26d4975a..69ee0b91 100644
--- a/axum/src/routing/route.rs
+++ b/axum/src/routing/route.rs
@@ -1,8 +1,9 @@
use crate::{
- body::{Body, BoxBody},
+ body::{box_body, Body, BoxBody},
clone_box_service::CloneBoxService,
};
use http::{Request, Response};
+use http_body::Empty;
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
@@ -18,37 +19,36 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
-pub struct Route(CloneBoxService, Response, Infallible>);
+pub struct Route(
+ pub(crate) CloneBoxService, Response, E>,
+);
-impl Route {
+impl Route {
pub(super) fn new(svc: T) -> Self
where
- T: Service, Response = Response, Error = Infallible>
- + Clone
- + Send
- + 'static,
+ T: Service, Response = Response, Error = E> + Clone + Send + 'static,
T::Future: Send + 'static,
{
Self(CloneBoxService::new(svc))
}
}
-impl Clone for Route {
+impl Clone for Route {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
-impl fmt::Debug for Route {
+impl fmt::Debug for Route {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Route").finish()
}
}
-impl Service> for Route {
+impl Service> for Route {
type Response = Response;
- type Error = Infallible;
- type Future = RouteFuture;
+ type Error = E;
+ type Future = RouteFuture;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
@@ -63,29 +63,50 @@ impl Service> for Route {
pin_project! {
/// Response future for [`Route`].
- pub struct RouteFuture {
+ pub struct RouteFuture {
#[pin]
future: Oneshot<
- CloneBoxService, Response, Infallible>,
+ CloneBoxService, Response, E>,
Request,
- >
+ >,
+ strip_body: bool,
}
}
-impl RouteFuture {
+impl RouteFuture {
pub(crate) fn new(
- future: Oneshot, Response, Infallible>, Request>,
+ future: Oneshot, Response, E>, Request>,
) -> Self {
- RouteFuture { future }
+ RouteFuture {
+ future,
+ strip_body: false,
+ }
+ }
+
+ pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
+ self.strip_body = strip_body;
+ self
}
}
-impl Future for RouteFuture {
- type Output = Result, Infallible>;
+impl Future for RouteFuture {
+ type Output = Result, E>;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
- self.project().future.poll(cx)
+ let strip_body = self.strip_body;
+
+ match self.project().future.poll(cx) {
+ Poll::Ready(Ok(res)) => {
+ if strip_body {
+ Poll::Ready(Ok(res.map(|_| box_body(Empty::new()))))
+ } else {
+ Poll::Ready(Ok(res))
+ }
+ }
+ Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
+ Poll::Pending => Poll::Pending,
+ }
}
}
diff --git a/axum/src/routing/service_method_routing.rs b/axum/src/routing/service_method_routing.rs
deleted file mode 100644
index 7588d560..00000000
--- a/axum/src/routing/service_method_routing.rs
+++ /dev/null
@@ -1,559 +0,0 @@
-//! Routing for [`Service`'s] based on HTTP methods.
-//!
-//! Most of the time applications will be written by composing
-//! [handlers](crate::handler), however sometimes you might have some general
-//! [`Service`] that you want to route requests to. That is enabled by the
-//! functions in this module.
-//!
-//! # Example
-//!
-//! Using [`Redirect`] to redirect requests can be done like so:
-//!
-//! ```
-//! use tower_http::services::Redirect;
-//! use axum::{
-//! body::Body,
-//! routing::{get, service_method_routing as service},
-//! http::Request,
-//! Router,
-//! };
-//!
-//! async fn handler(request: Request) { /* ... */ }
-//!
-//! let redirect_service = Redirect::::permanent("/new".parse().unwrap());
-//!
-//! let app = Router::new()
-//! .route("/old", service::get(redirect_service))
-//! .route("/new", get(handler));
-//! # async {
-//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
-//! # };
-//! ```
-//!
-//! # Regarding backpressure and `Service::poll_ready`
-//!
-//! Generally routing to one of multiple services and backpressure doesn't mix
-//! well. Ideally you would want ensure a service is ready to receive a request
-//! before calling it. However, in order to know which service to call, you need
-//! the request...
-//!
-//! One approach is to not consider the router service itself ready until all
-//! destination services are ready. That is the approach used by
-//! [`tower::steer::Steer`].
-//!
-//! Another approach is to always consider all services ready (always return
-//! `Poll::Ready(Ok(()))`) from `Service::poll_ready` and then actually drive
-//! readiness inside the response future returned by `Service::call`. This works
-//! well when your services don't care about backpressure and are always ready
-//! anyway.
-//!
-//! axum expects that all services used in your app wont care about
-//! backpressure and so it uses the latter strategy. However that means you
-//! should avoid routing to a service (or using a middleware) that _does_ care
-//! about backpressure. At the very least you should [load shed] so requests are
-//! dropped quickly and don't keep piling up.
-//!
-//! It also means that if `poll_ready` returns an error then that error will be
-//! returned in the response future from `call` and _not_ from `poll_ready`. In
-//! that case, the underlying service will _not_ be discarded and will continue
-//! to be used for future requests. Services that expect to be discarded if
-//! `poll_ready` fails should _not_ be used with axum.
-//!
-//! One possible approach is to only apply backpressure sensitive middleware
-//! around your entire app. This is possible because axum applications are
-//! themselves services:
-//!
-//! ```rust
-//! use axum::{
-//! routing::get,
-//! Router,
-//! };
-//! use tower::ServiceBuilder;
-//! # let some_backpressure_sensitive_middleware =
-//! # tower::layer::util::Identity::new();
-//!
-//! async fn handler() { /* ... */ }
-//!
-//! let app = Router::new().route("/", get(handler));
-//!
-//! let app = ServiceBuilder::new()
-//! .layer(some_backpressure_sensitive_middleware)
-//! .service(app);
-//! # async {
-//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
-//! # };
-//! ```
-//!
-//! However when applying middleware around your whole application in this way
-//! you have to take care that errors are still being handled with
-//! appropriately.
-//!
-//! Also note that handlers created from async functions don't care about
-//! backpressure and are always ready. So if you're not using any Tower
-//! middleware you don't have to worry about any of this.
-//!
-//! [`Redirect`]: tower_http::services::Redirect
-//! [load shed]: tower::load_shed
-//! [`Service`'s]: tower::Service
-
-use crate::{
- body::{box_body, BoxBody},
- routing::{MethodFilter, MethodNotAllowed},
- util::{Either, EitherProj},
- BoxError,
-};
-use bytes::Bytes;
-use futures_util::ready;
-use http::{Method, Request, Response};
-use http_body::Empty;
-use pin_project_lite::pin_project;
-use std::{
- fmt,
- future::Future,
- marker::PhantomData,
- pin::Pin,
- task::{Context, Poll},
-};
-use tower::{util::Oneshot, ServiceExt as _};
-use tower_service::Service;
-
-/// Route requests with any standard HTTP method to the given service.
-///
-/// See [`get`] for an example.
-///
-/// Note that this only accepts the standard HTTP methods. If you need to
-/// support non-standard methods you can route directly to a [`Service`].
-pub fn any