From 8fe4eaf1d5a17fc1919afc64f7c2e8c890069653 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Oct 2021 18:39:05 +0200 Subject: [PATCH] Run middleware for requests with no matching route (#422) While thinking about #419 I realized that `main` had a bug where middleware wouldn't be called if no route matched the incoming request. `Router` would just directly return a 404 without calling any middleware. This fixes by making the fallback default to a service that always returns 404 and always calling that if no route matches. Layers applied to the router is then also applied to the fallback. Unfortunately this breaks #380 but I don't currently see a way to support both. Auth middleware need to run _after_ routing because you don't care about auth for unknown paths, but logging middleware need to run _before_ routing because they do care about seeing requests for unknown paths so they can log them... Part of #419 --- CHANGELOG.md | 3 - src/routing/future.rs | 8 +- src/routing/method_not_allowed.rs | 69 ++++++++++++++ src/routing/mod.rs | 145 +++++++++++++++--------------- src/routing/not_found.rs | 38 ++++++++ src/tests/mod.rs | 72 +++++++++------ 6 files changed, 225 insertions(+), 110 deletions(-) create mode 100644 src/routing/method_not_allowed.rs create mode 100644 src/routing/not_found.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 900f1eb4..81b0c84d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,9 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Adding a conflicting route will now cause a panic instead of silently making a route unreachable. - **fixed:** Route matching is faster as number of routes increase. - - **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`) - now no longer catch requests that would otherwise be 404s. They also work - correctly with `Router::merge` (previously called `or`) ([#408]) - **fixed:** Correctly handle trailing slashes in routes: - If a route with a trailing slash exists and a request without a trailing slash is received, axum will send a 301 redirection to the route with the diff --git a/src/routing/future.rs b/src/routing/future.rs index 43cbebd7..df0c4349 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -13,6 +13,8 @@ use std::{ use tower::util::Oneshot; use tower_service::Service; +pub use super::method_not_allowed::MethodNotAllowedFuture; + opaque_future! { /// Response future for [`Router`](super::Router). pub type RouterFuture = @@ -60,12 +62,6 @@ impl Future for RouteFuture { } } -opaque_future! { - /// Response future for [`MethodNotAllowed`](super::MethodNotAllowed). - pub type MethodNotAllowedFuture = - std::future::Ready, E>>; -} - pin_project! { /// The response future for [`Nested`](super::Nested). #[derive(Debug)] diff --git a/src/routing/method_not_allowed.rs b/src/routing/method_not_allowed.rs new file mode 100644 index 00000000..6e29490e --- /dev/null +++ b/src/routing/method_not_allowed.rs @@ -0,0 +1,69 @@ +use crate::body::BoxBody; +use http::{Request, Response, StatusCode}; +use std::{ + convert::Infallible, + fmt, + future::ready, + marker::PhantomData, + task::{Context, Poll}, +}; +use tower_service::Service; + +/// A [`Service`] that responds with `405 Method not allowed` to all requests. +/// +/// This is used as the bottom service in a method router. You shouldn't have to +/// use it manually. +pub struct MethodNotAllowed { + _marker: PhantomData E>, +} + +impl MethodNotAllowed { + pub(crate) fn new() -> Self { + Self { + _marker: PhantomData, + } + } +} + +impl Clone for MethodNotAllowed { + fn clone(&self) -> Self { + Self { + _marker: PhantomData, + } + } +} + +impl fmt::Debug for MethodNotAllowed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("MethodNotAllowed").finish() + } +} + +impl Service> for MethodNotAllowed +where + B: Send + Sync + 'static, +{ + type Response = Response; + type Error = E; + type Future = MethodNotAllowedFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let res = Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(crate::body::empty()) + .unwrap(); + + MethodNotAllowedFuture::new(ready(Ok(res))) + } +} + +opaque_future! { + /// Response future for [`MethodNotAllowed`](super::MethodNotAllowed). + pub type MethodNotAllowedFuture = + std::future::Ready, E>>; +} diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 0fffea6d..8a3f92e3 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -1,6 +1,7 @@ //! Routing between [`Service`]s and handlers. -use self::future::{MethodNotAllowedFuture, NestedFuture, RouteFuture, RouterFuture}; +use self::future::{NestedFuture, RouteFuture, RouterFuture}; +use self::not_found::NotFound; use crate::{ body::{box_body, Body, BoxBody}, clone_box_service::CloneBoxService, @@ -19,7 +20,6 @@ use std::{ convert::Infallible, fmt, future::ready, - marker::PhantomData, sync::Arc, task::{Context, Poll}, }; @@ -33,8 +33,11 @@ pub mod handler_method_router; pub mod service_method_router; mod method_filter; +mod method_not_allowed; +mod not_found; pub use self::method_filter::MethodFilter; +pub(crate) use self::method_not_allowed::MethodNotAllowed; #[doc(no_inline)] pub use self::handler_method_router::{ @@ -56,7 +59,7 @@ impl RouteId { pub struct Router { routes: HashMap>, node: Node, - fallback: Option>, + fallback: Fallback, } impl Clone for Router { @@ -102,7 +105,7 @@ where Self { routes: Default::default(), node: Default::default(), - fallback: None, + fallback: Fallback::Default(Route::new(NotFound)), } } @@ -198,7 +201,7 @@ where panic!("Invalid route: {}", err); } - self.routes.insert(id, Route(CloneBoxService::new(svc))); + self.routes.insert(id, Route::new(svc)); self } @@ -350,8 +353,7 @@ where panic!("Invalid route: {}", err); } - self.routes - .insert(id, Route(CloneBoxService::new(Nested { svc }))); + self.routes.insert(id, Route::new(Nested { svc })); self } @@ -452,7 +454,7 @@ where }) .collect::>>(); - let fallback = self.fallback.map(|fallback| Layer::layer(&layer, fallback)); + let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc)); Router { routes, @@ -620,9 +622,12 @@ where assert!(self.routes.insert(id, route).is_none()); } - if let Some(new_fallback) = fallback { - self.fallback = Some(new_fallback); - } + self.fallback = match (self.fallback, fallback) { + (Fallback::Default(_), pick @ Fallback::Default(_)) => pick, + (Fallback::Default(_), pick @ Fallback::Custom(_)) => pick, + (pick @ Fallback::Custom(_), Fallback::Default(_)) => pick, + (Fallback::Custom(_), pick @ Fallback::Custom(_)) => pick, + }; self } @@ -728,7 +733,7 @@ where + 'static, T::Future: Send + 'static, { - self.fallback = Some(Route(CloneBoxService::new(svc))); + self.fallback = Fallback::Custom(Route::new(svc)); self } @@ -804,14 +809,15 @@ where .body(crate::body::empty()) .unwrap(); RouterFuture::from_response(res) - } else if let Some(fallback) = &self.fallback { - RouterFuture::from_oneshot(fallback.clone().oneshot(req)) } else { - let res = Response::builder() - .status(StatusCode::NOT_FOUND) - .body(crate::body::empty()) - .unwrap(); - RouterFuture::from_response(res) + match &self.fallback { + Fallback::Default(inner) => { + RouterFuture::from_oneshot(inner.clone().oneshot(req)) + } + Fallback::Custom(inner) => { + RouterFuture::from_oneshot(inner.clone().oneshot(req)) + } + } } } } @@ -875,59 +881,6 @@ pub(crate) struct InvalidUtf8InPathParam { pub(crate) key: ByteStr, } -/// A [`Service`] that responds with `405 Method not allowed` to all requests. -/// -/// This is used as the bottom service in a method router. You shouldn't have to -/// use it manually. -pub struct MethodNotAllowed { - _marker: PhantomData E>, -} - -impl MethodNotAllowed { - pub(crate) fn new() -> Self { - Self { - _marker: PhantomData, - } - } -} - -impl Clone for MethodNotAllowed { - fn clone(&self) -> Self { - Self { - _marker: PhantomData, - } - } -} - -impl fmt::Debug for MethodNotAllowed { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("MethodNotAllowed").finish() - } -} - -impl Service> for MethodNotAllowed -where - B: Send + Sync + 'static, -{ - type Response = Response; - type Error = E; - type Future = MethodNotAllowedFuture; - - #[inline] - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: Request) -> Self::Future { - let res = Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(crate::body::empty()) - .unwrap(); - - MethodNotAllowedFuture::new(ready(Ok(res))) - } -} - /// A [`Service`] that has been nested inside a router at some path. /// /// Created with [`Router::nest`]. @@ -1023,6 +976,19 @@ where /// You normally shouldn't need to care about this type. pub struct Route(CloneBoxService, Response, Infallible>); +impl Route { + fn new(svc: T) -> Self + where + T: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + T::Future: Send + 'static, + { + Self(CloneBoxService::new(svc)) + } +} + impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) @@ -1091,6 +1057,41 @@ impl fmt::Debug for Node { } } +enum Fallback { + Default(Route), + Custom(Route), +} + +impl Clone for Fallback { + fn clone(&self) -> Self { + match self { + Fallback::Default(inner) => Fallback::Default(inner.clone()), + Fallback::Custom(inner) => Fallback::Custom(inner.clone()), + } + } +} + +impl fmt::Debug for Fallback { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(), + Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(), + } + } +} + +impl Fallback { + fn map(self, f: F) -> Fallback + where + F: FnOnce(Route) -> Route, + { + match self { + Fallback::Default(inner) => Fallback::Default(f(inner)), + Fallback::Custom(inner) => Fallback::Custom(f(inner)), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/routing/not_found.rs b/src/routing/not_found.rs new file mode 100644 index 00000000..0ccc0859 --- /dev/null +++ b/src/routing/not_found.rs @@ -0,0 +1,38 @@ +use crate::body::BoxBody; +use http::{Request, Response, StatusCode}; +use std::{ + convert::Infallible, + future::ready, + task::{Context, Poll}, +}; +use tower_service::Service; + +/// A [`Service`] that responds with `405 Method not allowed` to all requests. +/// +/// This is used as the bottom service in a method router. You shouldn't have to +/// use it manually. +#[derive(Clone, Copy, Debug)] +pub(crate) struct NotFound; + +impl Service> for NotFound +where + B: Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = std::future::Ready, Self::Error>>; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let res = Response::builder() + .status(StatusCode::NOT_FOUND) + .body(crate::body::empty()) + .unwrap(); + + ready(Ok(res)) + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index c351cd79..f1c1e98c 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -19,6 +19,7 @@ use hyper::Body; use serde::Deserialize; use serde_json::{json, Value}; use std::future::Ready; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::{ collections::HashMap, convert::Infallible, @@ -552,35 +553,6 @@ async fn middleware_applies_to_routes_above() { assert_eq!(res.status(), StatusCode::OK); } -#[tokio::test] -async fn middleware_that_return_early() { - let app = Router::new() - .route("/", get(|| async {})) - .layer(RequireAuthorizationLayer::bearer("password")) - .route("/public", get(|| async {})); - - let client = TestClient::new(app); - - assert_eq!( - client.get("/").send().await.status(), - StatusCode::UNAUTHORIZED - ); - assert_eq!( - client - .get("/") - .header("authorization", "Bearer password") - .send() - .await - .status(), - StatusCode::OK - ); - assert_eq!( - client.get("/doesnt-exist").send().await.status(), - StatusCode::NOT_FOUND - ); - assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); -} - #[tokio::test] async fn with_trailing_slash() { let app = Router::new().route("/foo", get(|| async {})); @@ -672,6 +644,48 @@ async fn empty_route_nested() { TestClient::new(app); } +#[tokio::test] +async fn middleware_still_run_for_unmatched_requests() { + #[derive(Clone)] + struct CountMiddleware(S); + + static COUNT: AtomicUsize = AtomicUsize::new(0); + + impl Service for CountMiddleware + where + S: Service, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: R) -> Self::Future { + COUNT.fetch_add(1, Ordering::SeqCst); + self.0.call(req) + } + } + + let app = Router::new() + .route("/", get(|| async {})) + .layer(tower::layer::layer_fn(CountMiddleware)); + + let client = TestClient::new(app); + + assert_eq!(COUNT.load(Ordering::SeqCst), 0); + + client.get("/").send().await; + assert_eq!(COUNT.load(Ordering::SeqCst), 1); + + client.get("/not-found").send().await; + assert_eq!(COUNT.load(Ordering::SeqCst), 2); +} + +// TODO(david): middleware still run for empty routers + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {}