mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Run middleware for requests with no matching route (#422)
While thinking about #419 I realized that `main` had a bug where middleware wouldn't be called if no route matched the incoming request. `Router` would just directly return a 404 without calling any middleware. This fixes by making the fallback default to a service that always returns 404 and always calling that if no route matches. Layers applied to the router is then also applied to the fallback. Unfortunately this breaks #380 but I don't currently see a way to support both. Auth middleware need to run _after_ routing because you don't care about auth for unknown paths, but logging middleware need to run _before_ routing because they do care about seeing requests for unknown paths so they can log them... Part of #419
This commit is contained in:
parent
e9533c566f
commit
8fe4eaf1d5
6 changed files with 225 additions and 110 deletions
|
@ -41,9 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- **fixed:** Adding a conflicting route will now cause a panic instead of silently making
|
||||
a route unreachable.
|
||||
- **fixed:** Route matching is faster as number of routes increase.
|
||||
- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`)
|
||||
now no longer catch requests that would otherwise be 404s. They also work
|
||||
correctly with `Router::merge` (previously called `or`) ([#408])
|
||||
- **fixed:** Correctly handle trailing slashes in routes:
|
||||
- If a route with a trailing slash exists and a request without a trailing
|
||||
slash is received, axum will send a 301 redirection to the route with the
|
||||
|
|
|
@ -13,6 +13,8 @@ use std::{
|
|||
use tower::util::Oneshot;
|
||||
use tower_service::Service;
|
||||
|
||||
pub use super::method_not_allowed::MethodNotAllowedFuture;
|
||||
|
||||
opaque_future! {
|
||||
/// Response future for [`Router`](super::Router).
|
||||
pub type RouterFuture<B> =
|
||||
|
@ -60,12 +62,6 @@ impl<B> Future for RouteFuture<B> {
|
|||
}
|
||||
}
|
||||
|
||||
opaque_future! {
|
||||
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
|
||||
pub type MethodNotAllowedFuture<E> =
|
||||
std::future::Ready<Result<Response<BoxBody>, E>>;
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// The response future for [`Nested`](super::Nested).
|
||||
#[derive(Debug)]
|
||||
|
|
69
src/routing/method_not_allowed.rs
Normal file
69
src/routing/method_not_allowed.rs
Normal file
|
@ -0,0 +1,69 @@
|
|||
use crate::body::BoxBody;
|
||||
use http::{Request, Response, StatusCode};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
future::ready,
|
||||
marker::PhantomData,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower_service::Service;
|
||||
|
||||
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
||||
///
|
||||
/// This is used as the bottom service in a method router. You shouldn't have to
|
||||
/// use it manually.
|
||||
pub struct MethodNotAllowed<E = Infallible> {
|
||||
_marker: PhantomData<fn() -> E>,
|
||||
}
|
||||
|
||||
impl<E> MethodNotAllowed<E> {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Clone for MethodNotAllowed<E> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> fmt::Debug for MethodNotAllowed<E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("MethodNotAllowed").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = E;
|
||||
type Future = MethodNotAllowedFuture<E>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(crate::body::empty())
|
||||
.unwrap();
|
||||
|
||||
MethodNotAllowedFuture::new(ready(Ok(res)))
|
||||
}
|
||||
}
|
||||
|
||||
opaque_future! {
|
||||
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
|
||||
pub type MethodNotAllowedFuture<E> =
|
||||
std::future::Ready<Result<Response<BoxBody>, E>>;
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
//! Routing between [`Service`]s and handlers.
|
||||
|
||||
use self::future::{MethodNotAllowedFuture, NestedFuture, RouteFuture, RouterFuture};
|
||||
use self::future::{NestedFuture, RouteFuture, RouterFuture};
|
||||
use self::not_found::NotFound;
|
||||
use crate::{
|
||||
body::{box_body, Body, BoxBody},
|
||||
clone_box_service::CloneBoxService,
|
||||
|
@ -19,7 +20,6 @@ use std::{
|
|||
convert::Infallible,
|
||||
fmt,
|
||||
future::ready,
|
||||
marker::PhantomData,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
@ -33,8 +33,11 @@ pub mod handler_method_router;
|
|||
pub mod service_method_router;
|
||||
|
||||
mod method_filter;
|
||||
mod method_not_allowed;
|
||||
mod not_found;
|
||||
|
||||
pub use self::method_filter::MethodFilter;
|
||||
pub(crate) use self::method_not_allowed::MethodNotAllowed;
|
||||
|
||||
#[doc(no_inline)]
|
||||
pub use self::handler_method_router::{
|
||||
|
@ -56,7 +59,7 @@ impl RouteId {
|
|||
pub struct Router<B = Body> {
|
||||
routes: HashMap<RouteId, Route<B>>,
|
||||
node: Node,
|
||||
fallback: Option<Route<B>>,
|
||||
fallback: Fallback<B>,
|
||||
}
|
||||
|
||||
impl<B> Clone for Router<B> {
|
||||
|
@ -102,7 +105,7 @@ where
|
|||
Self {
|
||||
routes: Default::default(),
|
||||
node: Default::default(),
|
||||
fallback: None,
|
||||
fallback: Fallback::Default(Route::new(NotFound)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,7 +201,7 @@ where
|
|||
panic!("Invalid route: {}", err);
|
||||
}
|
||||
|
||||
self.routes.insert(id, Route(CloneBoxService::new(svc)));
|
||||
self.routes.insert(id, Route::new(svc));
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -350,8 +353,7 @@ where
|
|||
panic!("Invalid route: {}", err);
|
||||
}
|
||||
|
||||
self.routes
|
||||
.insert(id, Route(CloneBoxService::new(Nested { svc })));
|
||||
self.routes.insert(id, Route::new(Nested { svc }));
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -452,7 +454,7 @@ where
|
|||
})
|
||||
.collect::<HashMap<RouteId, Route<LayeredReqBody>>>();
|
||||
|
||||
let fallback = self.fallback.map(|fallback| Layer::layer(&layer, fallback));
|
||||
let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc));
|
||||
|
||||
Router {
|
||||
routes,
|
||||
|
@ -620,9 +622,12 @@ where
|
|||
assert!(self.routes.insert(id, route).is_none());
|
||||
}
|
||||
|
||||
if let Some(new_fallback) = fallback {
|
||||
self.fallback = Some(new_fallback);
|
||||
}
|
||||
self.fallback = match (self.fallback, fallback) {
|
||||
(Fallback::Default(_), pick @ Fallback::Default(_)) => pick,
|
||||
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
|
||||
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
|
||||
(Fallback::Custom(_), pick @ Fallback::Custom(_)) => pick,
|
||||
};
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -728,7 +733,7 @@ where
|
|||
+ 'static,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
self.fallback = Some(Route(CloneBoxService::new(svc)));
|
||||
self.fallback = Fallback::Custom(Route::new(svc));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -804,14 +809,15 @@ where
|
|||
.body(crate::body::empty())
|
||||
.unwrap();
|
||||
RouterFuture::from_response(res)
|
||||
} else if let Some(fallback) = &self.fallback {
|
||||
RouterFuture::from_oneshot(fallback.clone().oneshot(req))
|
||||
} else {
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(crate::body::empty())
|
||||
.unwrap();
|
||||
RouterFuture::from_response(res)
|
||||
match &self.fallback {
|
||||
Fallback::Default(inner) => {
|
||||
RouterFuture::from_oneshot(inner.clone().oneshot(req))
|
||||
}
|
||||
Fallback::Custom(inner) => {
|
||||
RouterFuture::from_oneshot(inner.clone().oneshot(req))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -875,59 +881,6 @@ pub(crate) struct InvalidUtf8InPathParam {
|
|||
pub(crate) key: ByteStr,
|
||||
}
|
||||
|
||||
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
||||
///
|
||||
/// This is used as the bottom service in a method router. You shouldn't have to
|
||||
/// use it manually.
|
||||
pub struct MethodNotAllowed<E = Infallible> {
|
||||
_marker: PhantomData<fn() -> E>,
|
||||
}
|
||||
|
||||
impl<E> MethodNotAllowed<E> {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> Clone for MethodNotAllowed<E> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> fmt::Debug for MethodNotAllowed<E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("MethodNotAllowed").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = E;
|
||||
type Future = MethodNotAllowedFuture<E>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||
.body(crate::body::empty())
|
||||
.unwrap();
|
||||
|
||||
MethodNotAllowedFuture::new(ready(Ok(res)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`Service`] that has been nested inside a router at some path.
|
||||
///
|
||||
/// Created with [`Router::nest`].
|
||||
|
@ -1023,6 +976,19 @@ where
|
|||
/// You normally shouldn't need to care about this type.
|
||||
pub struct Route<B = Body>(CloneBoxService<Request<B>, Response<BoxBody>, Infallible>);
|
||||
|
||||
impl<B> Route<B> {
|
||||
fn new<T>(svc: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
T::Future: Send + 'static,
|
||||
{
|
||||
Self(CloneBoxService::new(svc))
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody> Clone for Route<ReqBody> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
|
@ -1091,6 +1057,41 @@ impl fmt::Debug for Node {
|
|||
}
|
||||
}
|
||||
|
||||
enum Fallback<B> {
|
||||
Default(Route<B>),
|
||||
Custom(Route<B>),
|
||||
}
|
||||
|
||||
impl<B> Clone for Fallback<B> {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Fallback::Default(inner) => Fallback::Default(inner.clone()),
|
||||
Fallback::Custom(inner) => Fallback::Custom(inner.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> fmt::Debug for Fallback<B> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
|
||||
Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> Fallback<B> {
|
||||
fn map<F, B2>(self, f: F) -> Fallback<B2>
|
||||
where
|
||||
F: FnOnce(Route<B>) -> Route<B2>,
|
||||
{
|
||||
match self {
|
||||
Fallback::Default(inner) => Fallback::Default(f(inner)),
|
||||
Fallback::Custom(inner) => Fallback::Custom(f(inner)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
38
src/routing/not_found.rs
Normal file
38
src/routing/not_found.rs
Normal file
|
@ -0,0 +1,38 @@
|
|||
use crate::body::BoxBody;
|
||||
use http::{Request, Response, StatusCode};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::ready,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower_service::Service;
|
||||
|
||||
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
||||
///
|
||||
/// This is used as the bottom service in a method router. You shouldn't have to
|
||||
/// use it manually.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) struct NotFound;
|
||||
|
||||
impl<B> Service<Request<B>> for NotFound
|
||||
where
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = Infallible;
|
||||
type Future = std::future::Ready<Result<Response<BoxBody>, Self::Error>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(crate::body::empty())
|
||||
.unwrap();
|
||||
|
||||
ready(Ok(res))
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ use hyper::Body;
|
|||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::future::Ready;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
|
@ -552,35 +553,6 @@ async fn middleware_applies_to_routes_above() {
|
|||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn middleware_that_return_early() {
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.layer(RequireAuthorizationLayer::bearer("password"))
|
||||
.route("/public", get(|| async {}));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
assert_eq!(
|
||||
client.get("/").send().await.status(),
|
||||
StatusCode::UNAUTHORIZED
|
||||
);
|
||||
assert_eq!(
|
||||
client
|
||||
.get("/")
|
||||
.header("authorization", "Bearer password")
|
||||
.send()
|
||||
.await
|
||||
.status(),
|
||||
StatusCode::OK
|
||||
);
|
||||
assert_eq!(
|
||||
client.get("/doesnt-exist").send().await.status(),
|
||||
StatusCode::NOT_FOUND
|
||||
);
|
||||
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn with_trailing_slash() {
|
||||
let app = Router::new().route("/foo", get(|| async {}));
|
||||
|
@ -672,6 +644,48 @@ async fn empty_route_nested() {
|
|||
TestClient::new(app);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn middleware_still_run_for_unmatched_requests() {
|
||||
#[derive(Clone)]
|
||||
struct CountMiddleware<S>(S);
|
||||
|
||||
static COUNT: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
impl<R, S> Service<R> for CountMiddleware<S>
|
||||
where
|
||||
S: Service<R>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.0.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: R) -> Self::Future {
|
||||
COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
self.0.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async {}))
|
||||
.layer(tower::layer::layer_fn(CountMiddleware));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
|
||||
|
||||
client.get("/").send().await;
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
|
||||
|
||||
client.get("/not-found").send().await;
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
// TODO(david): middleware still run for empty routers
|
||||
|
||||
pub(crate) fn assert_send<T: Send>() {}
|
||||
pub(crate) fn assert_sync<T: Sync>() {}
|
||||
pub(crate) fn assert_unpin<T: Unpin>() {}
|
||||
|
|
Loading…
Reference in a new issue