diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 546553a9..0feea34c 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **breaking:** New `MethodRouter` that works similarly to `Router`: + - Route to handlers and services with the same type + - Add middleware to some routes more easily with `MethodRouter::layer` and + `MethodRouter::route_layer`. + - Merge method routers with `MethodRouter::merge` + - Customize response for unsupported methods with `MethodRouter::fallback` +- **fixed:** Adding the same route with different methods now works ie + `.route("/", get(_)).route("/", post(_))`. +- **breaking:** `routing::handler_method_router` and + `routing::service_method_router` has been removed in favor of + `routing::{get, get_service, ..., MethodRouter}`. +- **breaking:** `HandleErrorExt` has been removed in favor of + `MethodRouter::handle_error`. # 0.3.3 (13. November, 2021) diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md index 7c9c3572..508302ab 100644 --- a/axum/src/docs/error_handling.md +++ b/axum/src/docs/error_handling.md @@ -43,14 +43,12 @@ functions as handlers. However if you're embedding general `Service`s or applying middleware, which might produce errors you have to tell axum how to convert those errors into responses. -You can handle errors from services using [`HandleErrorExt::handle_error`]: - ```rust use axum::{ Router, body::Body, http::{Request, Response, StatusCode}, - error_handling::HandleErrorExt, // for `.handle_error()` + error_handling::HandleError, }; async fn thing_that_might_fail() -> Result<(), anyhow::Error> { @@ -69,7 +67,7 @@ let app = Router::new().route( // we cannot route to `some_fallible_service` directly since it might fail. // we have to use `handle_error` which converts its errors into responses // and changes its error type from `anyhow::Error` to `Infallible`. - some_fallible_service.handle_error(handle_anyhow_error), + HandleError::new(some_fallible_service, handle_anyhow_error), ); // handle errors by converting them into something that implements diff --git a/axum/src/docs/method_routing/fallback.md b/axum/src/docs/method_routing/fallback.md new file mode 100644 index 00000000..c027578c --- /dev/null +++ b/axum/src/docs/method_routing/fallback.md @@ -0,0 +1,53 @@ +Add a fallback service to the router. + +This service will be called if no routes matches the incoming request. + +```rust +use axum::{ + Router, + routing::get, + handler::Handler, + response::IntoResponse, + http::{StatusCode, Method, Uri}, +}; + +let handler = get(|| async {}).fallback(fallback.into_service()); + +let app = Router::new().route("/", handler); + +async fn fallback(method: Method, uri: Uri) -> impl IntoResponse { + (StatusCode::NOT_FOUND, format!("`{}` not allowed for {}", method, uri)) +} +# async { +# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +## When used with `MethodRouter::merge` + +Two routers that both have a fallback cannot be merged. Doing so results in a +panic: + +```rust,should_panic +use axum::{ + routing::{get, post}, + handler::Handler, + response::IntoResponse, + http::{StatusCode, Uri}, +}; + +let one = get(|| async {}) + .fallback(fallback_one.into_service()); + +let two = post(|| async {}) + .fallback(fallback_two.into_service()); + +let method_route = one.merge(two); + +async fn fallback_one() -> impl IntoResponse {} +async fn fallback_two() -> impl IntoResponse {} +# let app = axum::Router::new().route("/", method_route); +# async { +# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` diff --git a/axum/src/docs/method_routing/layer.md b/axum/src/docs/method_routing/layer.md new file mode 100644 index 00000000..10bde2ab --- /dev/null +++ b/axum/src/docs/method_routing/layer.md @@ -0,0 +1,28 @@ +Apply a [`tower::Layer`] to the router. + +All requests to the router will be processed by the layer's +corresponding middleware. + +This can be used to add additional processing to a request for a group +of routes. + +Works similarly to [`Router::layer`](super::Router::layer). See that method for +more details. + +# Example + +```rust +use axum::{routing::get, Router}; +use tower::limit::ConcurrencyLimitLayer; + +async fn hander() {} + +let app = Router::new().route( + "/", + // All requests to `GET /` will be sent through `ConcurrencyLimitLayer` + get(hander).layer(ConcurrencyLimitLayer::new(64)), +); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` diff --git a/axum/src/docs/method_routing/merge.md b/axum/src/docs/method_routing/merge.md new file mode 100644 index 00000000..39d74d04 --- /dev/null +++ b/axum/src/docs/method_routing/merge.md @@ -0,0 +1,25 @@ +Merge two routers into one. + +This is useful for breaking routers into smaller pieces and combining them +into one. + +```rust +use axum::{ + routing::{get, post}, + Router, +}; + +let get = get(|| async {}); +let post = post(|| async {}); + +let merged = get.merge(post); + +let app = Router::new().route("/", merged); + +// Our app now accepts +// - GET / +// - POST / +# async { +# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` diff --git a/axum/src/docs/method_routing/route_layer.md b/axum/src/docs/method_routing/route_layer.md new file mode 100644 index 00000000..c2e2c061 --- /dev/null +++ b/axum/src/docs/method_routing/route_layer.md @@ -0,0 +1,30 @@ +Apply a [`tower::Layer`] to the router that will only run if the request matches +a route. + +This works similarly to [`MethodRouter::layer`] except the middleware will only run if +the request matches a route. This is useful for middleware that return early +(such as authorization) which might otherwise convert a `405 Method Not Allowed` into a +`401 Unauthorized`. + +# Example + +```rust +use axum::{ + routing::get, + Router, +}; +use tower_http::auth::RequireAuthorizationLayer; + +let app = Router::new().route( + "/foo", + get(|| async {}) + .route_layer(RequireAuthorizationLayer::bearer("password")) +); + +// `GET /foo` with a valid token will receive `200 OK` +// `GET /foo` with a invalid token will receive `401 Unauthorized` +// `POST /FOO` with a invalid token will receive `405 Method Not Allowed` +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 468968dd..be913968 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -72,15 +72,14 @@ let app = Router::new().nest("/:version/api", users_api); ```rust use axum::{ Router, - routing::service_method_routing::get, - error_handling::HandleErrorExt, + routing::get_service, http::StatusCode, }; use std::{io, convert::Infallible}; use tower_http::services::ServeDir; // Serves files inside the `public` directory at `GET /public/*` -let serve_dir_service = ServeDir::new("public") +let serve_dir_service = get_service(ServeDir::new("public")) .handle_error(|error: io::Error| { ( StatusCode::INTERNAL_SERVER_ERROR, @@ -88,7 +87,7 @@ let serve_dir_service = ServeDir::new("public") ) }); -let app = Router::new().nest("/public", get(serve_dir_service)); +let app = Router::new().nest("/public", serve_dir_service); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index 7d7cb245..4b4ad4d1 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -111,8 +111,7 @@ axum also supports routing to general [`Service`]s: use axum::{ Router, body::Body, - routing::service_method_routing as service, - error_handling::HandleErrorExt, + routing::{any_service, get_service}, http::{Request, StatusCode}, }; use tower_http::services::ServeFile; @@ -125,9 +124,9 @@ let app = Router::new() // Any request to `/` goes to a service "/", // Services whose response body is not `axum::body::BoxBody` - // can be wrapped in `axum::service::any` (or one of the other routing filters) + // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) // to have the response body mapped - service::any(service_fn(|_: Request| async { + any_service(service_fn(|_: Request| async { let res = Response::new(Body::from("Hi from `GET /`")); Ok::<_, Infallible>(res) })) @@ -146,7 +145,7 @@ let app = Router::new() .route( // GET `/static/Cargo.toml` goes to a service from tower-http "/static/Cargo.toml", - service::get(ServeFile::new("Cargo.toml")) + get_service(ServeFile::new("Cargo.toml")) // though we must handle any potential errors .handle_error(|error: io::Error| { ( @@ -161,8 +160,10 @@ let app = Router::new() ``` Routing to arbitrary services in this way has complications for backpressure -([`Service::poll_ready`]). See the [`service_method_routing`] module for more -details. +([`Service::poll_ready`]). See the [Routing to services and backpressure] module +for more details. + +[Routing to services and backpressure]: /#routing-to-services-and-backpressure # Panics diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index 8eb20824..b1559549 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -134,19 +134,6 @@ where } } -/// Extension trait to [`Service`] for handling errors by mapping them to -/// responses. -/// -/// See [module docs](self) for more details on axum's error handling model. -pub trait HandleErrorExt: Service> + Sized { - /// Apply a [`HandleError`] middleware. - fn handle_error(self, f: F) -> HandleError { - HandleError::new(self, f) - } -} - -impl HandleErrorExt for S where S: Service> {} - pub mod future { //! Future types. diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index f0799f07..68e80045 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -404,11 +404,4 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "you said: hi there!"); } - - #[test] - fn traits() { - use crate::{routing::MethodRouter, test_helpers::*}; - assert_send::>(); - assert_sync::>(); - } } diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 49f33497..d2060072 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -11,6 +11,7 @@ //! - [Responses](#responses) //! - [Error handling](#error-handling) //! - [Middleware](#middleware) +//! - [Routing to services and backpressure](#routing-to-services-and-backpressure) //! - [Sharing state with handlers](#sharing-state-with-handlers) //! - [Required dependencies](#required-dependencies) //! - [Examples](#examples) @@ -160,6 +161,68 @@ //! #![doc = include_str!("docs/middleware.md")] //! +//! # Routing to services and backpressure +//! +//! Generally routing to one of multiple services and backpressure doesn't mix +//! well. Ideally you would want ensure a service is ready to receive a request +//! before calling it. However, in order to know which service to call, you need +//! the request... +//! +//! One approach is to not consider the router service itself ready until all +//! destination services are ready. That is the approach used by +//! [`tower::steer::Steer`]. +//! +//! Another approach is to always consider all services ready (always return +//! `Poll::Ready(Ok(()))`) from `Service::poll_ready` and then actually drive +//! readiness inside the response future returned by `Service::call`. This works +//! well when your services don't care about backpressure and are always ready +//! anyway. +//! +//! axum expects that all services used in your app wont care about +//! backpressure and so it uses the latter strategy. However that means you +//! should avoid routing to a service (or using a middleware) that _does_ care +//! about backpressure. At the very least you should [load shed] so requests are +//! dropped quickly and don't keep piling up. +//! +//! It also means that if `poll_ready` returns an error then that error will be +//! returned in the response future from `call` and _not_ from `poll_ready`. In +//! that case, the underlying service will _not_ be discarded and will continue +//! to be used for future requests. Services that expect to be discarded if +//! `poll_ready` fails should _not_ be used with axum. +//! +//! One possible approach is to only apply backpressure sensitive middleware +//! around your entire app. This is possible because axum applications are +//! themselves services: +//! +//! ```rust +//! use axum::{ +//! routing::get, +//! Router, +//! }; +//! use tower::ServiceBuilder; +//! # let some_backpressure_sensitive_middleware = +//! # tower::layer::util::Identity::new(); +//! +//! async fn handler() { /* ... */ } +//! +//! let app = Router::new().route("/", get(handler)); +//! +//! let app = ServiceBuilder::new() +//! .layer(some_backpressure_sensitive_middleware) +//! .service(app); +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! However when applying middleware around your whole application in this way +//! you have to take care that errors are still being handled with +//! appropriately. +//! +//! Also note that handlers created from async functions don't care about +//! backpressure and are always ready. So if you're not using any Tower +//! middleware you don't have to worry about any of this. +//! //! # Sharing state with handlers //! //! It is common to share some state between handlers for example to share a @@ -255,8 +318,8 @@ //! [`OriginalUri`]: crate::extract::OriginalUri //! [`Service`]: tower::Service //! [`Service::poll_ready`]: tower::Service::poll_ready +//! [`Service`'s]: tower::Service //! [`tower::Service`]: tower::Service -//! [`handle_error`]: error_handling::HandleErrorExt::handle_error //! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ //! [`FromRequest`]: crate::extract::FromRequest @@ -267,6 +330,7 @@ //! [`debug_handler`]: https://docs.rs/axum-debug/latest/axum_debug/attr.debug_handler.html //! [`Handler`]: crate::handler::Handler //! [`Infallible`]: std::convert::Infallible +//! [load shed]: tower::load_shed #![warn( clippy::all, diff --git a/axum/src/routing/future.rs b/axum/src/routing/future.rs index e09be0ae..1cd974ff 100644 --- a/axum/src/routing/future.rs +++ b/axum/src/routing/future.rs @@ -2,26 +2,22 @@ use crate::body::BoxBody; use futures_util::future::Either; -use http::{Request, Response}; +use http::Response; use std::{convert::Infallible, future::ready}; -use tower::util::Oneshot; -pub use super::{ - into_make_service::IntoMakeServiceFuture, method_not_allowed::MethodNotAllowedFuture, - route::RouteFuture, -}; +pub use super::{into_make_service::IntoMakeServiceFuture, route::RouteFuture}; opaque_future! { /// Response future for [`Router`](super::Router). pub type RouterFuture = futures_util::future::Either< - Oneshot, Request>, + RouteFuture, std::future::Ready, Infallible>>, >; } impl RouterFuture { - pub(super) fn from_oneshot(future: Oneshot, Request>) -> Self { + pub(super) fn from_future(future: RouteFuture) -> Self { Self::new(Either::Left(future)) } diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 251be637..82e823a5 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -1,5 +1,4 @@ use bitflags::bitflags; -use http::Method; bitflags! { /// A filter that matches one or more HTTP methods. @@ -22,21 +21,3 @@ bitflags! { const TRACE = 0b100000000; } } - -impl MethodFilter { - #[allow(clippy::match_like_matches_macro)] - pub(crate) fn matches(self, method: &Method) -> bool { - let method = match *method { - Method::DELETE => Self::DELETE, - Method::GET => Self::GET, - Method::HEAD => Self::HEAD, - Method::OPTIONS => Self::OPTIONS, - Method::PATCH => Self::PATCH, - Method::POST => Self::POST, - Method::PUT => Self::PUT, - Method::TRACE => Self::TRACE, - _ => return false, - }; - self.contains(method) - } -} diff --git a/axum/src/routing/method_not_allowed.rs b/axum/src/routing/method_not_allowed.rs deleted file mode 100644 index 6f812b11..00000000 --- a/axum/src/routing/method_not_allowed.rs +++ /dev/null @@ -1,82 +0,0 @@ -use crate::body::BoxBody; -use http::{Request, Response, StatusCode}; -use std::{ - convert::Infallible, - fmt, - future::ready, - marker::PhantomData, - task::{Context, Poll}, -}; -use tower_service::Service; - -/// A [`Service`] that responds with `405 Method not allowed` to all requests. -/// -/// This is used as the bottom service in a method router. You shouldn't have to -/// use it manually. -pub struct MethodNotAllowed { - _marker: PhantomData E>, -} - -impl MethodNotAllowed { - pub(crate) fn new() -> Self { - Self { - _marker: PhantomData, - } - } -} - -impl Clone for MethodNotAllowed { - fn clone(&self) -> Self { - Self { - _marker: PhantomData, - } - } -} - -impl fmt::Debug for MethodNotAllowed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("MethodNotAllowed").finish() - } -} - -impl Service> for MethodNotAllowed -where - B: Send + 'static, -{ - type Response = Response; - type Error = E; - type Future = MethodNotAllowedFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: Request) -> Self::Future { - let res = Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(crate::body::empty()) - .unwrap(); - - MethodNotAllowedFuture::new(ready(Ok(res))) - } -} - -opaque_future! { - /// Response future for [`MethodNotAllowed`](super::MethodNotAllowed). - pub type MethodNotAllowedFuture = - std::future::Ready, E>>; -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn traits() { - use crate::test_helpers::*; - - assert_send::>(); - assert_sync::>(); - } -} diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs new file mode 100644 index 00000000..271dfa45 --- /dev/null +++ b/axum/src/routing/method_routing.rs @@ -0,0 +1,1085 @@ +use crate::{ + body::{box_body, Body, BoxBody, Bytes}, + error_handling::HandleErrorLayer, + handler::Handler, + http::{Method, Request, Response, StatusCode}, + routing::{Fallback, MethodFilter, Route}, + BoxError, +}; +use http_body::Empty; +use std::{ + convert::Infallible, + fmt, + marker::PhantomData, + task::{Context, Poll}, +}; +use tower::{service_fn, ServiceBuilder, ServiceExt}; +use tower_http::map_response_body::MapResponseBodyLayer; +use tower_layer::Layer; +use tower_service::Service; + +macro_rules! top_level_service_fn { + ( + $name:ident, GET + ) => { + top_level_service_fn!( + /// Route `GET` requests to the given service. + /// + /// # Example + /// + /// ```rust + /// use axum::{ + /// http::Request, + /// Router, + /// routing::get_service, + /// }; + /// use http::Response; + /// use std::convert::Infallible; + /// use hyper::Body; + /// + /// let service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// // Requests to `GET /` will go to `service`. + /// let app = Router::new().route("/", get_service(service)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. + $name, + GET + ); + }; + + ( + $name:ident, $method:ident + ) => { + top_level_service_fn!( + #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")] + /// + /// See [`get_service`] for an example. + $name, + $method + ); + }; + + ( + $(#[$m:meta])+ + $name:ident, $method:ident + ) => { + $(#[$m])+ + pub fn $name(svc: S) -> MethodRouter + where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, + { + on_service(MethodFilter::$method, svc) + } + }; +} + +macro_rules! top_level_handler_fn { + ( + $name:ident, GET + ) => { + top_level_handler_fn!( + /// Route `GET` requests to the given handler. + /// + /// # Example + /// + /// ```rust + /// use axum::{ + /// routing::get, + /// Router, + /// }; + /// + /// async fn handler() {} + /// + /// // Requests to `GET /` will go to `handler`. + /// let app = Router::new().route("/", get(handler)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. + $name, + GET + ); + }; + + ( + $name:ident, $method:ident + ) => { + top_level_handler_fn!( + #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")] + /// + /// See [`get`] for an example. + $name, + $method + ); + }; + + ( + $(#[$m:meta])+ + $name:ident, $method:ident + ) => { + $(#[$m])+ + pub fn $name(handler: H) -> MethodRouter + where + H: Handler, + B: Send + 'static, + T: 'static, + { + on(MethodFilter::$method, handler) + } + }; +} + +macro_rules! chained_service_fn { + ( + $name:ident, GET + ) => { + chained_service_fn!( + /// Chain an additional service that will only accept `GET` requests. + /// + /// # Example + /// + /// ```rust + /// use axum::{ + /// http::Request, + /// Router, + /// routing::post_service, + /// }; + /// use http::Response; + /// use std::convert::Infallible; + /// use hyper::Body; + /// + /// let service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// let other_service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// // Requests to `GET /` will go to `service` and `POST /` will go to + /// // `other_service`. + /// let app = Router::new().route("/", post_service(service).get_service(other_service)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. + $name, + GET + ); + }; + + ( + $name:ident, $method:ident + ) => { + chained_service_fn!( + #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")] + /// + /// See [`MethodRouter::get_service`] for an example. + $name, + $method + ); + }; + + ( + $(#[$m:meta])+ + $name:ident, $method:ident + ) => { + $(#[$m])+ + pub fn $name(self, svc: S) -> Self + where + S: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, + { + self.on_service(MethodFilter::$method, svc) + } + }; +} + +macro_rules! chained_handler_fn { + ( + $name:ident, GET + ) => { + chained_handler_fn!( + /// Chain an additional handler that will only accept `GET` requests. + /// + /// # Example + /// + /// ```rust + /// use axum::{routing::post, Router}; + /// + /// async fn handler() {} + /// + /// async fn other_handler() {} + /// + /// // Requests to `GET /` will go to `handler` and `POST /` will go to + /// // `other_handler`. + /// let app = Router::new().route("/", post(handler).get(other_handler)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + /// + /// Note that `get` routes will also be called for `HEAD` requests but will have + /// the response body removed. Make sure to add explicit `HEAD` routes + /// afterwards. + $name, + GET + ); + }; + + ( + $name:ident, $method:ident + ) => { + chained_handler_fn!( + #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")] + /// + /// See [`MethodRouter::get`] for an example. + $name, + $method + ); + }; + + ( + $(#[$m:meta])+ + $name:ident, $method:ident + ) => { + $(#[$m])+ + pub fn $name(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + self.on(MethodFilter::$method, handler) + } + }; +} + +top_level_service_fn!(delete_service, DELETE); +top_level_service_fn!(get_service, GET); +top_level_service_fn!(head_service, HEAD); +top_level_service_fn!(options_service, OPTIONS); +top_level_service_fn!(patch_service, PATCH); +top_level_service_fn!(post_service, POST); +top_level_service_fn!(put_service, PUT); +top_level_service_fn!(trace_service, TRACE); + +/// Route requests with the given method to the service. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// http::Request, +/// routing::on, +/// Router, +/// routing::{MethodFilter, on_service}, +/// }; +/// use http::Response; +/// use std::convert::Infallible; +/// use hyper::Body; +/// +/// let service = tower::service_fn(|request: Request| async { +/// Ok::<_, Infallible>(Response::new(Body::empty())) +/// }); +/// +/// // Requests to `POST /` will go to `service`. +/// let app = Router::new().route("/", on_service(MethodFilter::POST, service)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub fn on_service( + filter: MethodFilter, + svc: S, +) -> MethodRouter +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, +{ + MethodRouter::new().on_service(filter, svc) +} + +/// Route requests to the given service regardless of its method. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// http::Request, +/// Router, +/// routing::any_service, +/// }; +/// use http::Response; +/// use std::convert::Infallible; +/// use hyper::Body; +/// +/// let service = tower::service_fn(|request: Request| async { +/// Ok::<_, Infallible>(Response::new(Body::empty())) +/// }); +/// +/// // All requests to `/` will go to `service`. +/// let app = Router::new().route("/", any_service(service)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Additional methods can still be chained: +/// +/// ```rust +/// use axum::{ +/// http::Request, +/// Router, +/// routing::any_service, +/// }; +/// use http::Response; +/// use std::convert::Infallible; +/// use hyper::Body; +/// +/// let service = tower::service_fn(|request: Request| async { +/// # Ok::<_, Infallible>(Response::new(Body::empty())) +/// // ... +/// }); +/// +/// let other_service = tower::service_fn(|request: Request| async { +/// # Ok::<_, Infallible>(Response::new(Body::empty())) +/// // ... +/// }); +/// +/// // `POST /` goes to `other_service`. All other requests go to `service` +/// let app = Router::new().route("/", any_service(service).post_service(other_service)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub fn any_service(svc: S) -> MethodRouter +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, +{ + MethodRouter::new().fallback(svc) +} + +top_level_handler_fn!(delete, DELETE); +top_level_handler_fn!(get, GET); +top_level_handler_fn!(head, HEAD); +top_level_handler_fn!(options, OPTIONS); +top_level_handler_fn!(patch, PATCH); +top_level_handler_fn!(post, POST); +top_level_handler_fn!(put, PUT); +top_level_handler_fn!(trace, TRACE); + +/// Route requests with the given method to the handler. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// routing::on, +/// Router, +/// routing::MethodFilter, +/// }; +/// +/// async fn handler() {} +/// +/// // Requests to `POST /` will go to `handler`. +/// let app = Router::new().route("/", on(MethodFilter::POST, handler)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub fn on(filter: MethodFilter, handler: H) -> MethodRouter +where + H: Handler, + B: Send + 'static, + T: 'static, +{ + MethodRouter::new().on(filter, handler) +} + +/// Route requests with the given handler regardless of the method. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// routing::any, +/// Router, +/// }; +/// +/// async fn handler() {} +/// +/// // All requests to `/` will go to `handler`. +/// let app = Router::new().route("/", any(handler)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Additional methods can still be chained: +/// +/// ```rust +/// use axum::{ +/// routing::any, +/// Router, +/// }; +/// +/// async fn handler() {} +/// +/// async fn other_handler() {} +/// +/// // `POST /` goes to `other_handler`. All other requests go to `handler` +/// let app = Router::new().route("/", any(handler).post(other_handler)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub fn any(handler: H) -> MethodRouter +where + H: Handler, + B: Send + 'static, + T: 'static, +{ + MethodRouter::new().fallback_boxed_response_body(handler.into_service()) +} + +/// A [`Service`] that accepts requests based on a [`MethodFilter`] and +/// allows chaining additional handlers and services. +pub struct MethodRouter { + get: Option>, + head: Option>, + delete: Option>, + options: Option>, + patch: Option>, + post: Option>, + put: Option>, + trace: Option>, + fallback: Fallback, + _request_body: PhantomData (B, E)>, +} + +impl fmt::Debug for MethodRouter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MethodRouter") + .field("get", &self.get) + .field("head", &self.head) + .field("delete", &self.delete) + .field("options", &self.options) + .field("patch", &self.patch) + .field("post", &self.post) + .field("put", &self.put) + .field("trace", &self.trace) + .field("fallback", &self.fallback) + .finish() + } +} + +impl MethodRouter { + /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all + /// requests. + pub fn new() -> Self { + let fallback = Route::new(service_fn(|_: Request| async { + let mut response = Response::new(box_body(Empty::new())); + *response.status_mut() = StatusCode::METHOD_NOT_ALLOWED; + Ok(response) + })); + + Self { + get: None, + head: None, + delete: None, + options: None, + patch: None, + post: None, + put: None, + trace: None, + fallback: Fallback::Default(fallback), + _request_body: PhantomData, + } + } +} + +impl MethodRouter +where + B: Send + 'static, +{ + /// Chain an additional handler that will accept requests matching the given + /// `MethodFilter`. + /// + /// # Example + /// + /// ```rust + /// use axum::{ + /// routing::get, + /// Router, + /// routing::MethodFilter + /// }; + /// + /// async fn handler() {} + /// + /// async fn other_handler() {} + /// + /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to + /// // `other_handler` + /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + pub fn on(self, filter: MethodFilter, handler: H) -> Self + where + H: Handler, + T: 'static, + { + self.on_service_boxed_response_body(filter, handler.into_service()) + } + + chained_handler_fn!(delete, DELETE); + chained_handler_fn!(get, GET); + chained_handler_fn!(head, HEAD); + chained_handler_fn!(options, OPTIONS); + chained_handler_fn!(patch, PATCH); + chained_handler_fn!(post, POST); + chained_handler_fn!(put, PUT); + chained_handler_fn!(trace, TRACE); +} + +impl MethodRouter { + /// Chain an additional service that will accept requests matching the given + /// `MethodFilter`. + /// + /// # Example + /// + /// ```rust + /// use axum::{ + /// http::Request, + /// Router, + /// routing::{MethodFilter, on_service}, + /// }; + /// use http::Response; + /// use std::convert::Infallible; + /// use hyper::Body; + /// + /// let service = tower::service_fn(|request: Request| async { + /// Ok::<_, Infallible>(Response::new(Body::empty())) + /// }); + /// + /// // Requests to `DELETE /` will go to `service` + /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service)); + /// # async { + /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + pub fn on_service(self, filter: MethodFilter, svc: S) -> Self + where + S: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, + { + self.on_service_boxed_response_body(filter, svc.map_response(|res| res.map(box_body))) + } + + chained_service_fn!(delete_service, DELETE); + chained_service_fn!(get_service, GET); + chained_service_fn!(head_service, HEAD); + chained_service_fn!(options_service, OPTIONS); + chained_service_fn!(patch_service, PATCH); + chained_service_fn!(post_service, POST); + chained_service_fn!(put_service, PUT); + chained_service_fn!(trace_service, TRACE); + + #[doc = include_str!("../docs/method_routing/fallback.md")] + pub fn fallback(mut self, svc: S) -> Self + where + S: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Error: Into, + { + self.fallback = Fallback::Custom(Route::new(svc.map_response(|res| res.map(box_body)))); + self + } + + fn fallback_boxed_response_body(mut self, svc: S) -> Self + where + S: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + { + self.fallback = Fallback::Custom(Route::new(svc)); + self + } + + #[doc = include_str!("../docs/method_routing/layer.md")] + pub fn layer( + self, + layer: L, + ) -> MethodRouter + where + L: Layer>, + L::Service: Service, Response = Response, Error = NewError> + + Clone + + Send + + 'static, + >>::Future: Send + 'static, + NewResBody: http_body::Body + Send + 'static, + NewResBody::Error: Into, + { + let layer = ServiceBuilder::new() + .layer_fn(Route::new) + .layer(MapResponseBodyLayer::new(box_body)) + .layer(layer) + .into_inner(); + let layer_fn = |s| layer.layer(s); + + MethodRouter { + get: self.get.map(layer_fn), + head: self.head.map(layer_fn), + delete: self.delete.map(layer_fn), + options: self.options.map(layer_fn), + patch: self.patch.map(layer_fn), + post: self.post.map(layer_fn), + put: self.put.map(layer_fn), + trace: self.trace.map(layer_fn), + fallback: self.fallback.map(layer_fn), + _request_body: PhantomData, + } + } + + #[doc = include_str!("../docs/method_routing/route_layer.md")] + pub fn route_layer(self, layer: L) -> MethodRouter + where + L: Layer>, + L::Service: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + >>::Future: Send + 'static, + NewResBody: http_body::Body + Send + 'static, + NewResBody::Error: Into, + { + let layer = ServiceBuilder::new() + .layer_fn(Route::new) + .layer(MapResponseBodyLayer::new(box_body)) + .layer(layer) + .into_inner(); + let layer_fn = |s| layer.layer(s); + + MethodRouter { + get: self.get.map(layer_fn), + head: self.head.map(layer_fn), + delete: self.delete.map(layer_fn), + options: self.options.map(layer_fn), + patch: self.patch.map(layer_fn), + post: self.post.map(layer_fn), + put: self.put.map(layer_fn), + trace: self.trace.map(layer_fn), + fallback: self.fallback, + _request_body: PhantomData, + } + } + + #[doc = include_str!("../docs/method_routing/merge.md")] + pub fn merge(self, other: MethodRouter) -> Self { + macro_rules! merge { + ( $first:ident, $second:ident ) => { + match ($first, $second) { + (Some(_), Some(_)) => panic!(concat!( + "Overlapping method route. Cannot merge two method routes that both define `", + stringify!($first), + "`" + )), + (Some(svc), None) => Some(svc), + (None, Some(svc)) => Some(svc), + (None, None) => None, + } + }; + } + + let Self { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + _request_body: _, + } = self; + + let Self { + get: get_other, + head: head_other, + delete: delete_other, + options: options_other, + patch: patch_other, + post: post_other, + put: put_other, + trace: trace_other, + fallback: fallback_other, + _request_body: _, + } = other; + + let get = merge!(get, get_other); + let head = merge!(head, head_other); + let delete = merge!(delete, delete_other); + let options = merge!(options, options_other); + let patch = merge!(patch, patch_other); + let post = merge!(post, post_other); + let put = merge!(put, put_other); + let trace = merge!(trace, trace_other); + + let fallback = match (fallback, fallback_other) { + (pick @ Fallback::Default(_), Fallback::Default(_)) => pick, + (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, + (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, + (Fallback::Custom(_), Fallback::Custom(_)) => { + panic!("Cannot merge two `MethodRouter`s that both have a fallback") + } + }; + + Self { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + _request_body: PhantomData, + } + } + + /// Apply a [`HandleErrorLayer`]. + /// + /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. + pub fn handle_error(self, f: F) -> MethodRouter + where + F: FnOnce(E) -> Res + Clone + Send + 'static, + Res: crate::response::IntoResponse, + ReqBody: Send + 'static, + E: 'static, + { + self.layer(HandleErrorLayer::new(f)) + } + + fn on_service_boxed_response_body(self, filter: MethodFilter, svc: S) -> Self + where + S: Service, Response = Response, Error = E> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + { + // written with a pattern match like this to ensure we update all fields + let Self { + mut get, + mut head, + mut delete, + mut options, + mut patch, + mut post, + mut put, + mut trace, + fallback, + _request_body: _, + } = self; + let svc = Some(Route::new(svc)); + if filter.contains(MethodFilter::GET) { + get = svc.clone(); + } + if filter.contains(MethodFilter::HEAD) { + head = svc.clone(); + } + if filter.contains(MethodFilter::DELETE) { + delete = svc.clone(); + } + if filter.contains(MethodFilter::OPTIONS) { + options = svc.clone(); + } + if filter.contains(MethodFilter::PATCH) { + patch = svc.clone(); + } + if filter.contains(MethodFilter::POST) { + post = svc.clone(); + } + if filter.contains(MethodFilter::PUT) { + put = svc.clone(); + } + if filter.contains(MethodFilter::TRACE) { + trace = svc; + } + Self { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + _request_body: PhantomData, + } + } +} + +impl Clone for MethodRouter { + fn clone(&self) -> Self { + Self { + get: self.get.clone(), + head: self.head.clone(), + delete: self.delete.clone(), + options: self.options.clone(), + patch: self.patch.clone(), + post: self.post.clone(), + put: self.put.clone(), + trace: self.trace.clone(), + fallback: self.fallback.clone(), + _request_body: PhantomData, + } + } +} + +impl Default for MethodRouter +where + B: Send + 'static, +{ + fn default() -> Self { + Self::new() + } +} + +use crate::routing::future::RouteFuture; + +impl Service> for MethodRouter { + type Response = Response; + type Error = E; + type Future = RouteFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + macro_rules! call { + ( + $req:expr, + $method:expr, + $method_variant:ident, + $svc:expr + ) => { + if $method == Method::$method_variant { + if let Some(svc) = $svc { + return RouteFuture::new(svc.0.clone().oneshot($req)) + .strip_body($method == Method::HEAD); + } + } + }; + } + + let method = req.method().clone(); + + // written with a pattern match like this to ensure we call all routes + let Self { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + _request_body: _, + } = self; + + call!(req, method, HEAD, head); + call!(req, method, HEAD, get); + call!(req, method, GET, get); + call!(req, method, POST, post); + call!(req, method, OPTIONS, options); + call!(req, method, PATCH, patch); + call!(req, method, PUT, put); + call!(req, method, DELETE, delete); + call!(req, method, TRACE, trace); + + match fallback { + Fallback::Default(fallback) => { + RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) + } + Fallback::Custom(fallback) => { + RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{body::Body, error_handling::HandleErrorLayer}; + use std::time::Duration; + use tower::{timeout::TimeoutLayer, Service, ServiceExt}; + use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir}; + + #[tokio::test] + async fn method_not_allowed_by_default() { + let mut svc = MethodRouter::new(); + let (status, body) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn get_handler() { + let mut svc = MethodRouter::new().get(ok); + let (status, body) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "ok"); + } + + #[tokio::test] + async fn get_accepts_head() { + let mut svc = MethodRouter::new().get(ok); + let (status, body) = call(Method::HEAD, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn head_takes_precedence_over_get() { + let mut svc = MethodRouter::new().head(created).get(ok); + let (status, body) = call(Method::HEAD, &mut svc).await; + assert_eq!(status, StatusCode::CREATED); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn merge() { + let mut svc = get(ok).merge(post(ok)); + + let (status, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + + let (status, _) = call(Method::POST, &mut svc).await; + assert_eq!(status, StatusCode::OK); + } + + #[tokio::test] + async fn layer() { + let mut svc = MethodRouter::new() + .get(|| async { std::future::pending::<()>().await }) + .layer(RequireAuthorizationLayer::bearer("password")); + + // method with route + let (status, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + + // method without route + let (status, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn route_layer() { + let mut svc = MethodRouter::new() + .get(|| async { std::future::pending::<()>().await }) + .route_layer(RequireAuthorizationLayer::bearer("password")); + + // method with route + let (status, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + + // method without route + let (status, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + } + + #[allow(dead_code)] + fn buiding_complex_router() { + let app = crate::Router::new().route( + "/", + // use the all the things :bomb: + get(ok) + .post(ok) + .route_layer(RequireAuthorizationLayer::bearer("password")) + .merge(delete_service(ServeDir::new(".")).handle_error(|_| StatusCode::NOT_FOUND)) + .fallback((|| async { StatusCode::NOT_FOUND }).into_service()) + .put(ok) + .layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT)) + .layer(TimeoutLayer::new(Duration::from_secs(10))), + ), + ); + + crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service()); + } + + async fn call(method: Method, svc: &mut S) -> (StatusCode, String) + where + S: Service, Response = Response, Error = Infallible>, + { + let request = Request::builder() + .uri("/") + .method(method) + .body(Body::empty()) + .unwrap(); + let response = svc.ready().await.unwrap().call(request).await.unwrap(); + let (parts, body) = response.into_parts(); + let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap(); + (parts.status, body) + } + + async fn ok() -> (StatusCode, &'static str) { + (StatusCode::OK, "ok") + } + + async fn created() -> (StatusCode, &'static str) { + (StatusCode::CREATED, "created") + } +} diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 54e70952..57689760 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -7,6 +7,7 @@ use crate::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, MatchedPath, OriginalUri, }, + routing::strip_prefix::StripPrefix, util::{ByteStr, PercentDecodedByteStr}, BoxError, }; @@ -20,18 +21,16 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use tower::{util::ServiceExt, ServiceBuilder}; +use tower::{layer::layer_fn, ServiceBuilder}; use tower_http::map_response_body::MapResponseBodyLayer; use tower_layer::Layer; use tower_service::Service; pub mod future; -pub mod handler_method_routing; -pub mod service_method_routing; mod into_make_service; mod method_filter; -mod method_not_allowed; +mod method_routing; mod not_found; mod route; mod strip_prefix; @@ -39,14 +38,12 @@ mod strip_prefix; #[cfg(test)] mod tests; -pub use self::{ - into_make_service::IntoMakeService, method_filter::MethodFilter, - method_not_allowed::MethodNotAllowed, route::Route, -}; +pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; -#[doc(no_inline)] -pub use self::handler_method_routing::{ - any, delete, get, head, on, options, patch, post, put, trace, MethodRouter, +pub use self::method_routing::{ + any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, + options, options_service, patch, patch_service, post, post_service, put, put_service, trace, + trace_service, MethodRouter, }; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -63,7 +60,7 @@ impl RouteId { /// The router type for composing handlers and services. #[derive(Debug)] pub struct Router { - routes: HashMap>, + routes: HashMap>, node: Node, fallback: Fallback, nested_at_root: bool, @@ -131,11 +128,32 @@ where let id = RouteId::next(); + let service = match try_downcast::, _>(service) { + Ok(method_router) => { + if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self + .node + .path_to_route_id + .get(path) + .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) + { + // if we're adding a new `MethodRouter` to a route that already has one just + // merge them. This makes `.route("/", get(_)).route("/", post(_))` work + let service = + Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); + self.routes.insert(route_id, service); + return self; + } else { + Endpoint::MethodRouter(method_router) + } + } + Err(service) => Endpoint::Route(Route::new(service)), + }; + if let Err(err) = self.node.insert(path, id) { self.panic_on_matchit_error(err); } - self.routes.insert(id, Route::new(service)); + self.routes.insert(id, service); self } @@ -179,14 +197,22 @@ where nested_at_root: _, } = router; - for (id, nested_path) in node.paths { + for (id, nested_path) in node.route_id_to_path { let route = routes.remove(&id).unwrap(); let full_path = if &*nested_path == "/" { path.to_string() } else { format!("{}{}", path, nested_path) }; - self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix)); + self = match route { + Endpoint::MethodRouter(method_router) => self.route( + &full_path, + method_router.layer(layer_fn(|s| StripPrefix::new(s, prefix))), + ), + Endpoint::Route(route) => { + self.route(&full_path, StripPrefix::new(route, prefix)) + } + }; } debug_assert!(routes.is_empty()); @@ -248,20 +274,25 @@ where NewResBody::Error: Into, { let layer = ServiceBuilder::new() - .layer_fn(Route::new) .layer(MapResponseBodyLayer::new(box_body)) - .layer(layer); + .layer(layer) + .into_inner(); let routes = self .routes .into_iter() .map(|(id, route)| { - let route = Layer::layer(&layer, route); + let route = match route { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.layer(&layer)) + } + Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), + }; (id, route) }) .collect(); - let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc)); + let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc))); Router { routes, @@ -284,15 +315,20 @@ where NewResBody::Error: Into, { let layer = ServiceBuilder::new() - .layer_fn(Route::new) .layer(MapResponseBodyLayer::new(box_body)) - .layer(layer); + .layer(layer) + .into_inner(); let routes = self .routes .into_iter() .map(|(id, route)| { - let route = Layer::layer(&layer, route); + let route = match route { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.layer(&layer)) + } + Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), + }; (id, route) }) .collect(); @@ -360,7 +396,7 @@ where let id = *match_.value; req.extensions_mut().insert(id); - if let Some(matched_path) = self.node.paths.get(&id) { + if let Some(matched_path) = self.node.route_id_to_path.get(&id) { let matched_path = if let Some(previous) = req.extensions_mut().get::() { // a previous `MatchedPath` might exist if we're inside a nested Router let previous = if let Some(previous) = @@ -388,13 +424,17 @@ where insert_url_params(&mut req, params); - let route = self + let mut route = self .routes .get(&id) .expect("no route for id. This is a bug in axum. Please file an issue") .clone(); - RouterFuture::from_oneshot(route.oneshot(req)) + let future = match &mut route { + Endpoint::MethodRouter(inner) => inner.call(req), + Endpoint::Route(inner) => inner.call(req), + }; + RouterFuture::from_future(future) } fn panic_on_matchit_error(&self, err: matchit::InsertError) { @@ -449,10 +489,10 @@ where } else { match &self.fallback { Fallback::Default(inner) => { - RouterFuture::from_oneshot(inner.clone().oneshot(req)) + RouterFuture::from_future(inner.clone().call(req)) } Fallback::Custom(inner) => { - RouterFuture::from_oneshot(inner.clone().oneshot(req)) + RouterFuture::from_future(inner.clone().call(req)) } } } @@ -537,7 +577,8 @@ pub(crate) struct InvalidUtf8InPathParam { #[derive(Clone, Default)] struct Node { inner: matchit::Node, - paths: HashMap>, + route_id_to_path: HashMap>, + path_to_route_id: HashMap, RouteId>, } impl Node { @@ -547,13 +588,18 @@ impl Node { val: RouteId, ) -> Result<(), matchit::InsertError> { let path = path.into(); + self.inner.insert(&path, val)?; - self.paths.insert(val, path.into()); + + let shared_path: Arc = path.into(); + self.route_id_to_path.insert(val, shared_path.clone()); + self.path_to_route_id.insert(shared_path, val); + Ok(()) } fn merge(&mut self, other: Node) -> Result<(), matchit::InsertError> { - for (id, path) in other.paths { + for (id, path) in other.route_id_to_path { self.insert(&*path, id)?; } Ok(()) @@ -569,16 +615,18 @@ impl Node { impl fmt::Debug for Node { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Node").field("paths", &self.paths).finish() + f.debug_struct("Node") + .field("paths", &self.route_id_to_path) + .finish() } } -enum Fallback { - Default(Route), - Custom(Route), +enum Fallback { + Default(Route), + Custom(Route), } -impl Clone for Fallback { +impl Clone for Fallback { fn clone(&self) -> Self { match self { Fallback::Default(inner) => Fallback::Default(inner.clone()), @@ -587,7 +635,7 @@ impl Clone for Fallback { } } -impl fmt::Debug for Fallback { +impl fmt::Debug for Fallback { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), @@ -596,10 +644,10 @@ impl fmt::Debug for Fallback { } } -impl Fallback { - fn map(self, f: F) -> Fallback +impl Fallback { + fn map(self, f: F) -> Fallback where - F: FnOnce(Route) -> Route, + F: FnOnce(Route) -> Route, { match self { Fallback::Default(inner) => Fallback::Default(f(inner)), @@ -622,6 +670,29 @@ where } } +enum Endpoint { + MethodRouter(MethodRouter), + Route(Route), +} + +impl Clone for Endpoint { + fn clone(&self) -> Self { + match self { + Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()), + Endpoint::Route(inner) => Endpoint::Route(inner.clone()), + } + } +} + +impl fmt::Debug for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MethodRouter(inner) => inner.fmt(f), + Self::Route(inner) => inner.fmt(f), + } + } +} + #[test] fn traits() { use crate::test_helpers::*; diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 26d4975a..69ee0b91 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -1,8 +1,9 @@ use crate::{ - body::{Body, BoxBody}, + body::{box_body, Body, BoxBody}, clone_box_service::CloneBoxService, }; use http::{Request, Response}; +use http_body::Empty; use pin_project_lite::pin_project; use std::{ convert::Infallible, @@ -18,37 +19,36 @@ use tower_service::Service; /// /// You normally shouldn't need to care about this type. It's used in /// [`Router::layer`](super::Router::layer). -pub struct Route(CloneBoxService, Response, Infallible>); +pub struct Route( + pub(crate) CloneBoxService, Response, E>, +); -impl Route { +impl Route { pub(super) fn new(svc: T) -> Self where - T: Service, Response = Response, Error = Infallible> - + Clone - + Send - + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, T::Future: Send + 'static, { Self(CloneBoxService::new(svc)) } } -impl Clone for Route { +impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl fmt::Debug for Route { +impl fmt::Debug for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Route").finish() } } -impl Service> for Route { +impl Service> for Route { type Response = Response; - type Error = Infallible; - type Future = RouteFuture; + type Error = E; + type Future = RouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -63,29 +63,50 @@ impl Service> for Route { pin_project! { /// Response future for [`Route`]. - pub struct RouteFuture { + pub struct RouteFuture { #[pin] future: Oneshot< - CloneBoxService, Response, Infallible>, + CloneBoxService, Response, E>, Request, - > + >, + strip_body: bool, } } -impl RouteFuture { +impl RouteFuture { pub(crate) fn new( - future: Oneshot, Response, Infallible>, Request>, + future: Oneshot, Response, E>, Request>, ) -> Self { - RouteFuture { future } + RouteFuture { + future, + strip_body: false, + } + } + + pub(crate) fn strip_body(mut self, strip_body: bool) -> Self { + self.strip_body = strip_body; + self } } -impl Future for RouteFuture { - type Output = Result, Infallible>; +impl Future for RouteFuture { + type Output = Result, E>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().future.poll(cx) + let strip_body = self.strip_body; + + match self.project().future.poll(cx) { + Poll::Ready(Ok(res)) => { + if strip_body { + Poll::Ready(Ok(res.map(|_| box_body(Empty::new())))) + } else { + Poll::Ready(Ok(res)) + } + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Pending, + } } } diff --git a/axum/src/routing/service_method_routing.rs b/axum/src/routing/service_method_routing.rs deleted file mode 100644 index 7588d560..00000000 --- a/axum/src/routing/service_method_routing.rs +++ /dev/null @@ -1,559 +0,0 @@ -//! Routing for [`Service`'s] based on HTTP methods. -//! -//! Most of the time applications will be written by composing -//! [handlers](crate::handler), however sometimes you might have some general -//! [`Service`] that you want to route requests to. That is enabled by the -//! functions in this module. -//! -//! # Example -//! -//! Using [`Redirect`] to redirect requests can be done like so: -//! -//! ``` -//! use tower_http::services::Redirect; -//! use axum::{ -//! body::Body, -//! routing::{get, service_method_routing as service}, -//! http::Request, -//! Router, -//! }; -//! -//! async fn handler(request: Request) { /* ... */ } -//! -//! let redirect_service = Redirect::::permanent("/new".parse().unwrap()); -//! -//! let app = Router::new() -//! .route("/old", service::get(redirect_service)) -//! .route("/new", get(handler)); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! # Regarding backpressure and `Service::poll_ready` -//! -//! Generally routing to one of multiple services and backpressure doesn't mix -//! well. Ideally you would want ensure a service is ready to receive a request -//! before calling it. However, in order to know which service to call, you need -//! the request... -//! -//! One approach is to not consider the router service itself ready until all -//! destination services are ready. That is the approach used by -//! [`tower::steer::Steer`]. -//! -//! Another approach is to always consider all services ready (always return -//! `Poll::Ready(Ok(()))`) from `Service::poll_ready` and then actually drive -//! readiness inside the response future returned by `Service::call`. This works -//! well when your services don't care about backpressure and are always ready -//! anyway. -//! -//! axum expects that all services used in your app wont care about -//! backpressure and so it uses the latter strategy. However that means you -//! should avoid routing to a service (or using a middleware) that _does_ care -//! about backpressure. At the very least you should [load shed] so requests are -//! dropped quickly and don't keep piling up. -//! -//! It also means that if `poll_ready` returns an error then that error will be -//! returned in the response future from `call` and _not_ from `poll_ready`. In -//! that case, the underlying service will _not_ be discarded and will continue -//! to be used for future requests. Services that expect to be discarded if -//! `poll_ready` fails should _not_ be used with axum. -//! -//! One possible approach is to only apply backpressure sensitive middleware -//! around your entire app. This is possible because axum applications are -//! themselves services: -//! -//! ```rust -//! use axum::{ -//! routing::get, -//! Router, -//! }; -//! use tower::ServiceBuilder; -//! # let some_backpressure_sensitive_middleware = -//! # tower::layer::util::Identity::new(); -//! -//! async fn handler() { /* ... */ } -//! -//! let app = Router::new().route("/", get(handler)); -//! -//! let app = ServiceBuilder::new() -//! .layer(some_backpressure_sensitive_middleware) -//! .service(app); -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! However when applying middleware around your whole application in this way -//! you have to take care that errors are still being handled with -//! appropriately. -//! -//! Also note that handlers created from async functions don't care about -//! backpressure and are always ready. So if you're not using any Tower -//! middleware you don't have to worry about any of this. -//! -//! [`Redirect`]: tower_http::services::Redirect -//! [load shed]: tower::load_shed -//! [`Service`'s]: tower::Service - -use crate::{ - body::{box_body, BoxBody}, - routing::{MethodFilter, MethodNotAllowed}, - util::{Either, EitherProj}, - BoxError, -}; -use bytes::Bytes; -use futures_util::ready; -use http::{Method, Request, Response}; -use http_body::Empty; -use pin_project_lite::pin_project; -use std::{ - fmt, - future::Future, - marker::PhantomData, - pin::Pin, - task::{Context, Poll}, -}; -use tower::{util::Oneshot, ServiceExt as _}; -use tower_service::Service; - -/// Route requests with any standard HTTP method to the given service. -/// -/// See [`get`] for an example. -/// -/// Note that this only accepts the standard HTTP methods. If you need to -/// support non-standard methods you can route directly to a [`Service`]. -pub fn any(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::all(), svc) -} - -/// Route `DELETE` requests to the given service. -/// -/// See [`get`] for an example. -pub fn delete(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::DELETE, svc) -} - -/// Route `GET` requests to the given service. -/// -/// # Example -/// -/// ```rust -/// use axum::{ -/// http::Request, -/// Router, -/// routing::service_method_routing as service, -/// }; -/// use http::Response; -/// use std::convert::Infallible; -/// use hyper::Body; -/// -/// let service = tower::service_fn(|request: Request| async { -/// Ok::<_, Infallible>(Response::new(Body::empty())) -/// }); -/// -/// // Requests to `GET /` will go to `service`. -/// let app = Router::new().route("/", service::get(service)); -/// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -/// -/// Note that `get` routes will also be called for `HEAD` requests but will have -/// the response body removed. Make sure to add explicit `HEAD` routes -/// afterwards. -pub fn get(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::GET | MethodFilter::HEAD, svc) -} - -/// Route `HEAD` requests to the given service. -/// -/// See [`get`] for an example. -pub fn head(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::HEAD, svc) -} - -/// Route `OPTIONS` requests to the given service. -/// -/// See [`get`] for an example. -pub fn options(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::OPTIONS, svc) -} - -/// Route `PATCH` requests to the given service. -/// -/// See [`get`] for an example. -pub fn patch(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::PATCH, svc) -} - -/// Route `POST` requests to the given service. -/// -/// See [`get`] for an example. -pub fn post(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::POST, svc) -} - -/// Route `PUT` requests to the given service. -/// -/// See [`get`] for an example. -pub fn put(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::PUT, svc) -} - -/// Route `TRACE` requests to the given service. -/// -/// See [`get`] for an example. -pub fn trace(svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - on(MethodFilter::TRACE, svc) -} - -/// Route requests with the given method to the service. -/// -/// # Example -/// -/// ```rust -/// use axum::{ -/// http::Request, -/// routing::on, -/// Router, -/// routing::{MethodFilter, service_method_routing as service}, -/// }; -/// use http::Response; -/// use std::convert::Infallible; -/// use hyper::Body; -/// -/// let service = tower::service_fn(|request: Request| async { -/// Ok::<_, Infallible>(Response::new(Body::empty())) -/// }); -/// -/// // Requests to `POST /` will go to `service`. -/// let app = Router::new().route("/", service::on(MethodFilter::POST, service)); -/// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -pub fn on(method: MethodFilter, svc: S) -> MethodRouter, B> -where - S: Service> + Clone, -{ - MethodRouter { - method, - svc, - fallback: MethodNotAllowed::new(), - _request_body: PhantomData, - } -} - -/// A [`Service`] that accepts requests based on a [`MethodFilter`] and allows -/// chaining additional services. -pub struct MethodRouter { - pub(crate) method: MethodFilter, - pub(crate) svc: S, - pub(crate) fallback: F, - pub(crate) _request_body: PhantomData B>, -} - -impl fmt::Debug for MethodRouter -where - S: fmt::Debug, - F: fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MethodRouter") - .field("method", &self.method) - .field("svc", &self.svc) - .field("fallback", &self.fallback) - .finish() - } -} - -impl Clone for MethodRouter -where - S: Clone, - F: Clone, -{ - fn clone(&self) -> Self { - Self { - method: self.method, - svc: self.svc.clone(), - fallback: self.fallback.clone(), - _request_body: PhantomData, - } - } -} - -impl MethodRouter { - /// Chain an additional service that will accept all requests regardless of - /// its HTTP method. - /// - /// See [`MethodRouter::get`] for an example. - pub fn any(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::all(), svc) - } - - /// Chain an additional service that will only accept `DELETE` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn delete(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::DELETE, svc) - } - - /// Chain an additional service that will only accept `GET` requests. - /// - /// # Example - /// - /// ```rust - /// use axum::{ - /// http::Request, - /// Router, - /// routing::{MethodFilter, on, service_method_routing as service}, - /// }; - /// use http::Response; - /// use std::convert::Infallible; - /// use hyper::Body; - /// - /// let service = tower::service_fn(|request: Request| async { - /// Ok::<_, Infallible>(Response::new(Body::empty())) - /// }); - /// - /// let other_service = tower::service_fn(|request: Request| async { - /// Ok::<_, Infallible>(Response::new(Body::empty())) - /// }); - /// - /// // Requests to `GET /` will go to `service` and `POST /` will go to - /// // `other_service`. - /// let app = Router::new().route("/", service::post(service).get(other_service)); - /// # async { - /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` - /// - /// Note that `get` routes will also be called for `HEAD` requests but will have - /// the response body removed. Make sure to add explicit `HEAD` routes - /// afterwards. - pub fn get(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::GET | MethodFilter::HEAD, svc) - } - - /// Chain an additional service that will only accept `HEAD` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn head(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::HEAD, svc) - } - - /// Chain an additional service that will only accept `OPTIONS` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn options(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::OPTIONS, svc) - } - - /// Chain an additional service that will only accept `PATCH` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn patch(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::PATCH, svc) - } - - /// Chain an additional service that will only accept `POST` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn post(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::POST, svc) - } - - /// Chain an additional service that will only accept `PUT` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn put(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::PUT, svc) - } - - /// Chain an additional service that will only accept `TRACE` requests. - /// - /// See [`MethodRouter::get`] for an example. - pub fn trace(self, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - self.on(MethodFilter::TRACE, svc) - } - - /// Chain an additional service that will accept requests matching the given - /// `MethodFilter`. - /// - /// # Example - /// - /// ```rust - /// use axum::{ - /// http::Request, - /// Router, - /// routing::{MethodFilter, on, service_method_routing as service}, - /// }; - /// use http::Response; - /// use std::convert::Infallible; - /// use hyper::Body; - /// - /// let service = tower::service_fn(|request: Request| async { - /// Ok::<_, Infallible>(Response::new(Body::empty())) - /// }); - /// - /// let other_service = tower::service_fn(|request: Request| async { - /// Ok::<_, Infallible>(Response::new(Body::empty())) - /// }); - /// - /// // Requests to `DELETE /` will go to `service` - /// let app = Router::new().route("/", service::on(MethodFilter::DELETE, service)); - /// # async { - /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` - pub fn on(self, method: MethodFilter, svc: T) -> MethodRouter - where - T: Service> + Clone, - { - MethodRouter { - method, - svc, - fallback: self, - _request_body: PhantomData, - } - } -} - -impl Service> for MethodRouter -where - S: Service, Response = Response> + Clone, - ResBody: http_body::Body + Send + 'static, - ResBody::Error: Into, - F: Service, Response = Response, Error = S::Error> + Clone, -{ - type Response = Response; - type Error = S::Error; - type Future = MethodRouterFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let req_method = req.method().clone(); - - let f = if self.method.matches(req.method()) { - let fut = self.svc.clone().oneshot(req); - Either::A { inner: fut } - } else { - let fut = self.fallback.clone().oneshot(req); - Either::B { inner: fut } - }; - - MethodRouterFuture { - inner: f, - req_method, - } - } -} - -pin_project! { - /// The response future for [`MethodRouter`]. - pub struct MethodRouterFuture - where - S: Service>, - F: Service> - { - #[pin] - pub(super) inner: Either< - Oneshot>, - Oneshot>, - >, - pub(super) req_method: Method, - } -} - -impl Future for MethodRouterFuture -where - S: Service, Response = Response> + Clone, - ResBody: http_body::Body + Send + 'static, - ResBody::Error: Into, - F: Service, Response = Response, Error = S::Error>, -{ - type Output = Result, S::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - let response = match this.inner.project() { - EitherProj::A { inner } => ready!(inner.poll(cx))?.map(box_body), - EitherProj::B { inner } => ready!(inner.poll(cx))?, - }; - - if this.req_method == &Method::HEAD { - let response = response.map(|_| box_body(Empty::new())); - Poll::Ready(Ok(response)) - } else { - Poll::Ready(Ok(response)) - } - } -} - -#[test] -fn traits() { - use crate::test_helpers::*; - - assert_send::>(); - assert_sync::>(); -} diff --git a/axum/src/routing/tests/get_to_head.rs b/axum/src/routing/tests/get_to_head.rs index 50f1e355..a7fe623e 100644 --- a/axum/src/routing/tests/get_to_head.rs +++ b/axum/src/routing/tests/get_to_head.rs @@ -4,6 +4,7 @@ use tower::ServiceExt; mod for_handlers { use super::*; + use headers::HeaderMap; #[tokio::test] async fn get_handles_head() { @@ -38,14 +39,14 @@ mod for_handlers { mod for_services { use super::*; - use crate::routing::service_method_routing::get; + use crate::routing::get_service; use http::header::HeaderValue; #[tokio::test] async fn get_handles_head() { let app = Router::new().route( "/", - get(service_fn(|_req: Request| async move { + get_service(service_fn(|_req: Request| async move { let res = Response::builder() .header("x-some-header", "foobar".parse::().unwrap()) .body(Body::from("you shouldn't see this")) diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index f50e5a90..6b2fedfb 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -196,18 +196,18 @@ async fn many_ors() { #[tokio::test] async fn services() { - use crate::routing::service_method_routing::get; + use crate::routing::get_service; let app = Router::new() .route( "/foo", - get(service_fn(|_: Request| async { + get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::empty())) })), ) .merge(Router::new().route( "/bar", - get(service_fn(|_: Request| async { + get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::empty())) })), )); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index af2ecfdc..56a33006 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -3,12 +3,12 @@ use crate::{ extract::{self, Path}, handler::Handler, response::IntoResponse, - routing::{any, delete, get, on, patch, post, service_method_routing as service, MethodFilter}, + routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, test_helpers::*, BoxError, Json, Router, }; use bytes::Bytes; -use http::{header::HeaderMap, Method, Request, Response, StatusCode, Uri}; +use http::{Method, Request, Response, StatusCode, Uri}; use hyper::Body; use serde::Deserialize; use serde_json::{json, Value}; @@ -135,20 +135,20 @@ async fn routing_between_services() { let app = Router::new() .route( "/one", - service::get(service_fn(|_: Request| async { + get_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one get"))) })) - .post(service_fn(|_: Request| async { + .post_service(service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one post"))) })) - .on( + .on_service( MethodFilter::PUT, service_fn(|_: Request| async { Ok::<_, Infallible>(Response::new(Body::from("one put"))) }), ), ) - .route("/two", service::on(MethodFilter::GET, any(handle))); + .route("/two", on_service(MethodFilter::GET, handle.into_service())); let client = TestClient::new(app); @@ -202,7 +202,7 @@ async fn service_in_bottom() { Ok(Response::new(hyper::Body::empty())) } - let app = Router::new().route("/", service::get(service_fn(handler))); + let app = Router::new().route("/", get_service(service_fn(handler))); TestClient::new(app); } @@ -248,8 +248,8 @@ async fn wrong_method_service() { } let app = Router::new() - .route("/", service::get(Svc).post(Svc)) - .route("/foo", service::patch(Svc)); + .route("/", get_service(Svc).post_service(Svc)) + .route("/foo", patch_service(Svc)); let client = TestClient::new(app); @@ -505,17 +505,6 @@ async fn route_layer() { assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } -#[tokio::test] -#[should_panic( - expected = "Invalid route: insertion failed due to conflict with previously registered route: /foo" -)] -async fn conflicting_route() { - let app = Router::new() - .route("/foo", get(|| async {})) - .route("/foo", get(|| async {})); - TestClient::new(app); -} - #[tokio::test] #[should_panic( expected = "Invalid route: insertion failed due to conflict with previously registered route: /*axum_nest. Note that `nest(\"/\", _)` conflicts with all routes. Use `Router::fallback` instead" @@ -537,3 +526,43 @@ async fn good_error_message_if_using_nest_root_when_merging() { let app = one.merge(two); TestClient::new(app); } + +#[tokio::test] +async fn different_methods_added_in_different_routes() { + let app = Router::new() + .route("/", get(|| async { "GET" })) + .route("/", post(|| async { "POST" })); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + let body = res.text().await; + assert_eq!(body, "GET"); + + let res = client.post("/").send().await; + let body = res.text().await; + assert_eq!(body, "POST"); +} + +#[tokio::test] +async fn different_methods_added_in_different_routes_deeply_nested() { + let app = Router::new() + .route("/foo/bar/baz", get(|| async { "GET" })) + .nest( + "/foo", + Router::new().nest( + "/bar", + Router::new().route("/baz", post(|| async { "POST" })), + ), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar/baz").send().await; + let body = res.text().await; + assert_eq!(body, "GET"); + + let res = client.post("/foo/bar/baz").send().await; + let body = res.text().await; + assert_eq!(body, "POST"); +} diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 39f74073..25aa41d2 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -1,5 +1,7 @@ +use tower_http::services::ServeDir; + use super::*; -use crate::{body::box_body, error_handling::HandleErrorExt, extract::Extension}; +use crate::{body::box_body, extract::Extension}; use std::collections::HashMap; #[tokio::test] @@ -167,7 +169,7 @@ async fn nested_service_sees_stripped_uri() { async fn nest_static_file_server() { let app = Router::new().nest( "/static", - service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| { + get_service(ServeDir::new(".")).handle_error(|error| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index aa572ef3..e7e253fc 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -5,11 +5,10 @@ //! ``` use axum::{ - error_handling::HandleErrorExt, extract::TypedHeader, http::StatusCode, response::sse::{Event, Sse}, - routing::{get, service_method_routing as service}, + routing::{get, get_service}, Router, }; use futures::stream::{self, Stream}; @@ -26,7 +25,7 @@ async fn main() { tracing_subscriber::fmt::init(); let static_files_service = - service::get(ServeDir::new("examples/sse/assets").append_index_html_on_directories(true)) + get_service(ServeDir::new("examples/sse/assets").append_index_html_on_directories(true)) .handle_error(|error: std::io::Error| { ( StatusCode::INTERNAL_SERVER_ERROR, diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 5524a078..4d087efc 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -4,10 +4,7 @@ //! cargo run -p example-static-file-server //! ``` -use axum::{ - error_handling::HandleErrorExt, http::StatusCode, routing::service_method_routing as service, - Router, -}; +use axum::{http::StatusCode, routing::get_service, Router}; use std::net::SocketAddr; use tower_http::{services::ServeDir, trace::TraceLayer}; @@ -25,7 +22,7 @@ async fn main() { let app = Router::new() .nest( "/static", - service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| { + get_service(ServeDir::new(".")).handle_error(|error: std::io::Error| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 1f744878..3680f0c9 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -7,14 +7,13 @@ //! ``` use axum::{ - error_handling::HandleErrorExt, extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, TypedHeader, }, http::StatusCode, response::IntoResponse, - routing::{get, service_method_routing as service}, + routing::{get, get_service}, Router, }; use std::net::SocketAddr; @@ -34,7 +33,7 @@ async fn main() { // build our application with some routes let app = Router::new() .fallback( - service::get( + get_service( ServeDir::new("examples/websockets/assets").append_index_html_on_directories(true), ) .handle_error(|error: std::io::Error| {