Run middleware for requests with no matching route (#422)

While thinking about #419 I realized that `main` had a bug where
middleware wouldn't be called if no route matched the incoming request.
`Router` would just directly return a 404 without calling any
middleware.

This fixes by making the fallback default to a service that always
returns 404 and always calling that if no route matches.

Layers applied to the router is then also applied to the fallback.

Unfortunately this breaks #380 but I don't currently see a way to
support both. Auth middleware need to run _after_ routing because you
don't care about auth for unknown paths, but logging middleware need to
run _before_ routing because they do care about seeing requests for
unknown paths so they can log them...

Part of #419
This commit is contained in:
David Pedersen 2021-10-26 18:39:05 +02:00 committed by GitHub
parent e9533c566f
commit 8fe4eaf1d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 225 additions and 110 deletions

View file

@ -41,9 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Adding a conflicting route will now cause a panic instead of silently making - **fixed:** Adding a conflicting route will now cause a panic instead of silently making
a route unreachable. a route unreachable.
- **fixed:** Route matching is faster as number of routes increase. - **fixed:** Route matching is faster as number of routes increase.
- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`)
now no longer catch requests that would otherwise be 404s. They also work
correctly with `Router::merge` (previously called `or`) ([#408])
- **fixed:** Correctly handle trailing slashes in routes: - **fixed:** Correctly handle trailing slashes in routes:
- If a route with a trailing slash exists and a request without a trailing - If a route with a trailing slash exists and a request without a trailing
slash is received, axum will send a 301 redirection to the route with the slash is received, axum will send a 301 redirection to the route with the

View file

@ -13,6 +13,8 @@ use std::{
use tower::util::Oneshot; use tower::util::Oneshot;
use tower_service::Service; use tower_service::Service;
pub use super::method_not_allowed::MethodNotAllowedFuture;
opaque_future! { opaque_future! {
/// Response future for [`Router`](super::Router). /// Response future for [`Router`](super::Router).
pub type RouterFuture<B> = pub type RouterFuture<B> =
@ -60,12 +62,6 @@ impl<B> Future for RouteFuture<B> {
} }
} }
opaque_future! {
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
pub type MethodNotAllowedFuture<E> =
std::future::Ready<Result<Response<BoxBody>, E>>;
}
pin_project! { pin_project! {
/// The response future for [`Nested`](super::Nested). /// The response future for [`Nested`](super::Nested).
#[derive(Debug)] #[derive(Debug)]

View file

@ -0,0 +1,69 @@
use crate::body::BoxBody;
use http::{Request, Response, StatusCode};
use std::{
convert::Infallible,
fmt,
future::ready,
marker::PhantomData,
task::{Context, Poll},
};
use tower_service::Service;
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
///
/// This is used as the bottom service in a method router. You shouldn't have to
/// use it manually.
pub struct MethodNotAllowed<E = Infallible> {
_marker: PhantomData<fn() -> E>,
}
impl<E> MethodNotAllowed<E> {
pub(crate) fn new() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<E> Clone for MethodNotAllowed<E> {
fn clone(&self) -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<E> fmt::Debug for MethodNotAllowed<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MethodNotAllowed").finish()
}
}
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
where
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = E;
type Future = MethodNotAllowedFuture<E>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<B>) -> Self::Future {
let res = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(crate::body::empty())
.unwrap();
MethodNotAllowedFuture::new(ready(Ok(res)))
}
}
opaque_future! {
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
pub type MethodNotAllowedFuture<E> =
std::future::Ready<Result<Response<BoxBody>, E>>;
}

View file

@ -1,6 +1,7 @@
//! Routing between [`Service`]s and handlers. //! Routing between [`Service`]s and handlers.
use self::future::{MethodNotAllowedFuture, NestedFuture, RouteFuture, RouterFuture}; use self::future::{NestedFuture, RouteFuture, RouterFuture};
use self::not_found::NotFound;
use crate::{ use crate::{
body::{box_body, Body, BoxBody}, body::{box_body, Body, BoxBody},
clone_box_service::CloneBoxService, clone_box_service::CloneBoxService,
@ -19,7 +20,6 @@ use std::{
convert::Infallible, convert::Infallible,
fmt, fmt,
future::ready, future::ready,
marker::PhantomData,
sync::Arc, sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -33,8 +33,11 @@ pub mod handler_method_router;
pub mod service_method_router; pub mod service_method_router;
mod method_filter; mod method_filter;
mod method_not_allowed;
mod not_found;
pub use self::method_filter::MethodFilter; pub use self::method_filter::MethodFilter;
pub(crate) use self::method_not_allowed::MethodNotAllowed;
#[doc(no_inline)] #[doc(no_inline)]
pub use self::handler_method_router::{ pub use self::handler_method_router::{
@ -56,7 +59,7 @@ impl RouteId {
pub struct Router<B = Body> { pub struct Router<B = Body> {
routes: HashMap<RouteId, Route<B>>, routes: HashMap<RouteId, Route<B>>,
node: Node, node: Node,
fallback: Option<Route<B>>, fallback: Fallback<B>,
} }
impl<B> Clone for Router<B> { impl<B> Clone for Router<B> {
@ -102,7 +105,7 @@ where
Self { Self {
routes: Default::default(), routes: Default::default(),
node: Default::default(), node: Default::default(),
fallback: None, fallback: Fallback::Default(Route::new(NotFound)),
} }
} }
@ -198,7 +201,7 @@ where
panic!("Invalid route: {}", err); panic!("Invalid route: {}", err);
} }
self.routes.insert(id, Route(CloneBoxService::new(svc))); self.routes.insert(id, Route::new(svc));
self self
} }
@ -350,8 +353,7 @@ where
panic!("Invalid route: {}", err); panic!("Invalid route: {}", err);
} }
self.routes self.routes.insert(id, Route::new(Nested { svc }));
.insert(id, Route(CloneBoxService::new(Nested { svc })));
self self
} }
@ -452,7 +454,7 @@ where
}) })
.collect::<HashMap<RouteId, Route<LayeredReqBody>>>(); .collect::<HashMap<RouteId, Route<LayeredReqBody>>>();
let fallback = self.fallback.map(|fallback| Layer::layer(&layer, fallback)); let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc));
Router { Router {
routes, routes,
@ -620,9 +622,12 @@ where
assert!(self.routes.insert(id, route).is_none()); assert!(self.routes.insert(id, route).is_none());
} }
if let Some(new_fallback) = fallback { self.fallback = match (self.fallback, fallback) {
self.fallback = Some(new_fallback); (Fallback::Default(_), pick @ Fallback::Default(_)) => pick,
} (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
(Fallback::Custom(_), pick @ Fallback::Custom(_)) => pick,
};
self self
} }
@ -728,7 +733,7 @@ where
+ 'static, + 'static,
T::Future: Send + 'static, T::Future: Send + 'static,
{ {
self.fallback = Some(Route(CloneBoxService::new(svc))); self.fallback = Fallback::Custom(Route::new(svc));
self self
} }
@ -804,14 +809,15 @@ where
.body(crate::body::empty()) .body(crate::body::empty())
.unwrap(); .unwrap();
RouterFuture::from_response(res) RouterFuture::from_response(res)
} else if let Some(fallback) = &self.fallback {
RouterFuture::from_oneshot(fallback.clone().oneshot(req))
} else { } else {
let res = Response::builder() match &self.fallback {
.status(StatusCode::NOT_FOUND) Fallback::Default(inner) => {
.body(crate::body::empty()) RouterFuture::from_oneshot(inner.clone().oneshot(req))
.unwrap(); }
RouterFuture::from_response(res) Fallback::Custom(inner) => {
RouterFuture::from_oneshot(inner.clone().oneshot(req))
}
}
} }
} }
} }
@ -875,59 +881,6 @@ pub(crate) struct InvalidUtf8InPathParam {
pub(crate) key: ByteStr, pub(crate) key: ByteStr,
} }
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
///
/// This is used as the bottom service in a method router. You shouldn't have to
/// use it manually.
pub struct MethodNotAllowed<E = Infallible> {
_marker: PhantomData<fn() -> E>,
}
impl<E> MethodNotAllowed<E> {
pub(crate) fn new() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<E> Clone for MethodNotAllowed<E> {
fn clone(&self) -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<E> fmt::Debug for MethodNotAllowed<E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MethodNotAllowed").finish()
}
}
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
where
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = E;
type Future = MethodNotAllowedFuture<E>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<B>) -> Self::Future {
let res = Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(crate::body::empty())
.unwrap();
MethodNotAllowedFuture::new(ready(Ok(res)))
}
}
/// A [`Service`] that has been nested inside a router at some path. /// A [`Service`] that has been nested inside a router at some path.
/// ///
/// Created with [`Router::nest`]. /// Created with [`Router::nest`].
@ -1023,6 +976,19 @@ where
/// You normally shouldn't need to care about this type. /// You normally shouldn't need to care about this type.
pub struct Route<B = Body>(CloneBoxService<Request<B>, Response<BoxBody>, Infallible>); pub struct Route<B = Body>(CloneBoxService<Request<B>, Response<BoxBody>, Infallible>);
impl<B> Route<B> {
fn new<T>(svc: T) -> Self
where
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
T::Future: Send + 'static,
{
Self(CloneBoxService::new(svc))
}
}
impl<ReqBody> Clone for Route<ReqBody> { impl<ReqBody> Clone for Route<ReqBody> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self(self.0.clone()) Self(self.0.clone())
@ -1091,6 +1057,41 @@ impl fmt::Debug for Node {
} }
} }
enum Fallback<B> {
Default(Route<B>),
Custom(Route<B>),
}
impl<B> Clone for Fallback<B> {
fn clone(&self) -> Self {
match self {
Fallback::Default(inner) => Fallback::Default(inner.clone()),
Fallback::Custom(inner) => Fallback::Custom(inner.clone()),
}
}
}
impl<B> fmt::Debug for Fallback<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(),
}
}
}
impl<B> Fallback<B> {
fn map<F, B2>(self, f: F) -> Fallback<B2>
where
F: FnOnce(Route<B>) -> Route<B2>,
{
match self {
Fallback::Default(inner) => Fallback::Default(f(inner)),
Fallback::Custom(inner) => Fallback::Custom(f(inner)),
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

38
src/routing/not_found.rs Normal file
View file

@ -0,0 +1,38 @@
use crate::body::BoxBody;
use http::{Request, Response, StatusCode};
use std::{
convert::Infallible,
future::ready,
task::{Context, Poll},
};
use tower_service::Service;
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
///
/// This is used as the bottom service in a method router. You shouldn't have to
/// use it manually.
#[derive(Clone, Copy, Debug)]
pub(crate) struct NotFound;
impl<B> Service<Request<B>> for NotFound
where
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = std::future::Ready<Result<Response<BoxBody>, Self::Error>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: Request<B>) -> Self::Future {
let res = Response::builder()
.status(StatusCode::NOT_FOUND)
.body(crate::body::empty())
.unwrap();
ready(Ok(res))
}
}

View file

@ -19,6 +19,7 @@ use hyper::Body;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::future::Ready; use std::future::Ready;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{ use std::{
collections::HashMap, collections::HashMap,
convert::Infallible, convert::Infallible,
@ -552,35 +553,6 @@ async fn middleware_applies_to_routes_above() {
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[tokio::test]
async fn middleware_that_return_early() {
let app = Router::new()
.route("/", get(|| async {}))
.layer(RequireAuthorizationLayer::bearer("password"))
.route("/public", get(|| async {}));
let client = TestClient::new(app);
assert_eq!(
client.get("/").send().await.status(),
StatusCode::UNAUTHORIZED
);
assert_eq!(
client
.get("/")
.header("authorization", "Bearer password")
.send()
.await
.status(),
StatusCode::OK
);
assert_eq!(
client.get("/doesnt-exist").send().await.status(),
StatusCode::NOT_FOUND
);
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
}
#[tokio::test] #[tokio::test]
async fn with_trailing_slash() { async fn with_trailing_slash() {
let app = Router::new().route("/foo", get(|| async {})); let app = Router::new().route("/foo", get(|| async {}));
@ -672,6 +644,48 @@ async fn empty_route_nested() {
TestClient::new(app); TestClient::new(app);
} }
#[tokio::test]
async fn middleware_still_run_for_unmatched_requests() {
#[derive(Clone)]
struct CountMiddleware<S>(S);
static COUNT: AtomicUsize = AtomicUsize::new(0);
impl<R, S> Service<R> for CountMiddleware<S>
where
S: Service<R>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: R) -> Self::Future {
COUNT.fetch_add(1, Ordering::SeqCst);
self.0.call(req)
}
}
let app = Router::new()
.route("/", get(|| async {}))
.layer(tower::layer::layer_fn(CountMiddleware));
let client = TestClient::new(app);
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
client.get("/").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
client.get("/not-found").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
}
// TODO(david): middleware still run for empty routers
pub(crate) fn assert_send<T: Send>() {} pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {} pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {} pub(crate) fn assert_unpin<T: Unpin>() {}