mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Run middleware for requests with no matching route (#422)
While thinking about #419 I realized that `main` had a bug where middleware wouldn't be called if no route matched the incoming request. `Router` would just directly return a 404 without calling any middleware. This fixes by making the fallback default to a service that always returns 404 and always calling that if no route matches. Layers applied to the router is then also applied to the fallback. Unfortunately this breaks #380 but I don't currently see a way to support both. Auth middleware need to run _after_ routing because you don't care about auth for unknown paths, but logging middleware need to run _before_ routing because they do care about seeing requests for unknown paths so they can log them... Part of #419
This commit is contained in:
parent
e9533c566f
commit
8fe4eaf1d5
6 changed files with 225 additions and 110 deletions
|
@ -41,9 +41,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- **fixed:** Adding a conflicting route will now cause a panic instead of silently making
|
- **fixed:** Adding a conflicting route will now cause a panic instead of silently making
|
||||||
a route unreachable.
|
a route unreachable.
|
||||||
- **fixed:** Route matching is faster as number of routes increase.
|
- **fixed:** Route matching is faster as number of routes increase.
|
||||||
- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`)
|
|
||||||
now no longer catch requests that would otherwise be 404s. They also work
|
|
||||||
correctly with `Router::merge` (previously called `or`) ([#408])
|
|
||||||
- **fixed:** Correctly handle trailing slashes in routes:
|
- **fixed:** Correctly handle trailing slashes in routes:
|
||||||
- If a route with a trailing slash exists and a request without a trailing
|
- If a route with a trailing slash exists and a request without a trailing
|
||||||
slash is received, axum will send a 301 redirection to the route with the
|
slash is received, axum will send a 301 redirection to the route with the
|
||||||
|
|
|
@ -13,6 +13,8 @@ use std::{
|
||||||
use tower::util::Oneshot;
|
use tower::util::Oneshot;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
|
pub use super::method_not_allowed::MethodNotAllowedFuture;
|
||||||
|
|
||||||
opaque_future! {
|
opaque_future! {
|
||||||
/// Response future for [`Router`](super::Router).
|
/// Response future for [`Router`](super::Router).
|
||||||
pub type RouterFuture<B> =
|
pub type RouterFuture<B> =
|
||||||
|
@ -60,12 +62,6 @@ impl<B> Future for RouteFuture<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
opaque_future! {
|
|
||||||
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
|
|
||||||
pub type MethodNotAllowedFuture<E> =
|
|
||||||
std::future::Ready<Result<Response<BoxBody>, E>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
/// The response future for [`Nested`](super::Nested).
|
/// The response future for [`Nested`](super::Nested).
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
69
src/routing/method_not_allowed.rs
Normal file
69
src/routing/method_not_allowed.rs
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
use crate::body::BoxBody;
|
||||||
|
use http::{Request, Response, StatusCode};
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
fmt,
|
||||||
|
future::ready,
|
||||||
|
marker::PhantomData,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tower_service::Service;
|
||||||
|
|
||||||
|
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
||||||
|
///
|
||||||
|
/// This is used as the bottom service in a method router. You shouldn't have to
|
||||||
|
/// use it manually.
|
||||||
|
pub struct MethodNotAllowed<E = Infallible> {
|
||||||
|
_marker: PhantomData<fn() -> E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E> MethodNotAllowed<E> {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
_marker: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E> Clone for MethodNotAllowed<E> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
_marker: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E> fmt::Debug for MethodNotAllowed<E> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_tuple("MethodNotAllowed").finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
|
||||||
|
where
|
||||||
|
B: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Response = Response<BoxBody>;
|
||||||
|
type Error = E;
|
||||||
|
type Future = MethodNotAllowedFuture<E>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||||
|
let res = Response::builder()
|
||||||
|
.status(StatusCode::METHOD_NOT_ALLOWED)
|
||||||
|
.body(crate::body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
MethodNotAllowedFuture::new(ready(Ok(res)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
opaque_future! {
|
||||||
|
/// Response future for [`MethodNotAllowed`](super::MethodNotAllowed).
|
||||||
|
pub type MethodNotAllowedFuture<E> =
|
||||||
|
std::future::Ready<Result<Response<BoxBody>, E>>;
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
//! Routing between [`Service`]s and handlers.
|
//! Routing between [`Service`]s and handlers.
|
||||||
|
|
||||||
use self::future::{MethodNotAllowedFuture, NestedFuture, RouteFuture, RouterFuture};
|
use self::future::{NestedFuture, RouteFuture, RouterFuture};
|
||||||
|
use self::not_found::NotFound;
|
||||||
use crate::{
|
use crate::{
|
||||||
body::{box_body, Body, BoxBody},
|
body::{box_body, Body, BoxBody},
|
||||||
clone_box_service::CloneBoxService,
|
clone_box_service::CloneBoxService,
|
||||||
|
@ -19,7 +20,6 @@ use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
fmt,
|
fmt,
|
||||||
future::ready,
|
future::ready,
|
||||||
marker::PhantomData,
|
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
@ -33,8 +33,11 @@ pub mod handler_method_router;
|
||||||
pub mod service_method_router;
|
pub mod service_method_router;
|
||||||
|
|
||||||
mod method_filter;
|
mod method_filter;
|
||||||
|
mod method_not_allowed;
|
||||||
|
mod not_found;
|
||||||
|
|
||||||
pub use self::method_filter::MethodFilter;
|
pub use self::method_filter::MethodFilter;
|
||||||
|
pub(crate) use self::method_not_allowed::MethodNotAllowed;
|
||||||
|
|
||||||
#[doc(no_inline)]
|
#[doc(no_inline)]
|
||||||
pub use self::handler_method_router::{
|
pub use self::handler_method_router::{
|
||||||
|
@ -56,7 +59,7 @@ impl RouteId {
|
||||||
pub struct Router<B = Body> {
|
pub struct Router<B = Body> {
|
||||||
routes: HashMap<RouteId, Route<B>>,
|
routes: HashMap<RouteId, Route<B>>,
|
||||||
node: Node,
|
node: Node,
|
||||||
fallback: Option<Route<B>>,
|
fallback: Fallback<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B> Clone for Router<B> {
|
impl<B> Clone for Router<B> {
|
||||||
|
@ -102,7 +105,7 @@ where
|
||||||
Self {
|
Self {
|
||||||
routes: Default::default(),
|
routes: Default::default(),
|
||||||
node: Default::default(),
|
node: Default::default(),
|
||||||
fallback: None,
|
fallback: Fallback::Default(Route::new(NotFound)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,7 +201,7 @@ where
|
||||||
panic!("Invalid route: {}", err);
|
panic!("Invalid route: {}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.routes.insert(id, Route(CloneBoxService::new(svc)));
|
self.routes.insert(id, Route::new(svc));
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -350,8 +353,7 @@ where
|
||||||
panic!("Invalid route: {}", err);
|
panic!("Invalid route: {}", err);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.routes
|
self.routes.insert(id, Route::new(Nested { svc }));
|
||||||
.insert(id, Route(CloneBoxService::new(Nested { svc })));
|
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -452,7 +454,7 @@ where
|
||||||
})
|
})
|
||||||
.collect::<HashMap<RouteId, Route<LayeredReqBody>>>();
|
.collect::<HashMap<RouteId, Route<LayeredReqBody>>>();
|
||||||
|
|
||||||
let fallback = self.fallback.map(|fallback| Layer::layer(&layer, fallback));
|
let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc));
|
||||||
|
|
||||||
Router {
|
Router {
|
||||||
routes,
|
routes,
|
||||||
|
@ -620,9 +622,12 @@ where
|
||||||
assert!(self.routes.insert(id, route).is_none());
|
assert!(self.routes.insert(id, route).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(new_fallback) = fallback {
|
self.fallback = match (self.fallback, fallback) {
|
||||||
self.fallback = Some(new_fallback);
|
(Fallback::Default(_), pick @ Fallback::Default(_)) => pick,
|
||||||
}
|
(Fallback::Default(_), pick @ Fallback::Custom(_)) => pick,
|
||||||
|
(pick @ Fallback::Custom(_), Fallback::Default(_)) => pick,
|
||||||
|
(Fallback::Custom(_), pick @ Fallback::Custom(_)) => pick,
|
||||||
|
};
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -728,7 +733,7 @@ where
|
||||||
+ 'static,
|
+ 'static,
|
||||||
T::Future: Send + 'static,
|
T::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
self.fallback = Some(Route(CloneBoxService::new(svc)));
|
self.fallback = Fallback::Custom(Route::new(svc));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -804,14 +809,15 @@ where
|
||||||
.body(crate::body::empty())
|
.body(crate::body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
RouterFuture::from_response(res)
|
RouterFuture::from_response(res)
|
||||||
} else if let Some(fallback) = &self.fallback {
|
|
||||||
RouterFuture::from_oneshot(fallback.clone().oneshot(req))
|
|
||||||
} else {
|
} else {
|
||||||
let res = Response::builder()
|
match &self.fallback {
|
||||||
.status(StatusCode::NOT_FOUND)
|
Fallback::Default(inner) => {
|
||||||
.body(crate::body::empty())
|
RouterFuture::from_oneshot(inner.clone().oneshot(req))
|
||||||
.unwrap();
|
}
|
||||||
RouterFuture::from_response(res)
|
Fallback::Custom(inner) => {
|
||||||
|
RouterFuture::from_oneshot(inner.clone().oneshot(req))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -875,59 +881,6 @@ pub(crate) struct InvalidUtf8InPathParam {
|
||||||
pub(crate) key: ByteStr,
|
pub(crate) key: ByteStr,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
|
||||||
///
|
|
||||||
/// This is used as the bottom service in a method router. You shouldn't have to
|
|
||||||
/// use it manually.
|
|
||||||
pub struct MethodNotAllowed<E = Infallible> {
|
|
||||||
_marker: PhantomData<fn() -> E>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E> MethodNotAllowed<E> {
|
|
||||||
pub(crate) fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
_marker: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E> Clone for MethodNotAllowed<E> {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
Self {
|
|
||||||
_marker: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E> fmt::Debug for MethodNotAllowed<E> {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
f.debug_tuple("MethodNotAllowed").finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B, E> Service<Request<B>> for MethodNotAllowed<E>
|
|
||||||
where
|
|
||||||
B: Send + Sync + 'static,
|
|
||||||
{
|
|
||||||
type Response = Response<BoxBody>;
|
|
||||||
type Error = E;
|
|
||||||
type Future = MethodNotAllowedFuture<E>;
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
|
||||||
let res = Response::builder()
|
|
||||||
.status(StatusCode::METHOD_NOT_ALLOWED)
|
|
||||||
.body(crate::body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
MethodNotAllowedFuture::new(ready(Ok(res)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A [`Service`] that has been nested inside a router at some path.
|
/// A [`Service`] that has been nested inside a router at some path.
|
||||||
///
|
///
|
||||||
/// Created with [`Router::nest`].
|
/// Created with [`Router::nest`].
|
||||||
|
@ -1023,6 +976,19 @@ where
|
||||||
/// You normally shouldn't need to care about this type.
|
/// You normally shouldn't need to care about this type.
|
||||||
pub struct Route<B = Body>(CloneBoxService<Request<B>, Response<BoxBody>, Infallible>);
|
pub struct Route<B = Body>(CloneBoxService<Request<B>, Response<BoxBody>, Infallible>);
|
||||||
|
|
||||||
|
impl<B> Route<B> {
|
||||||
|
fn new<T>(svc: T) -> Self
|
||||||
|
where
|
||||||
|
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
|
||||||
|
+ Clone
|
||||||
|
+ Send
|
||||||
|
+ 'static,
|
||||||
|
T::Future: Send + 'static,
|
||||||
|
{
|
||||||
|
Self(CloneBoxService::new(svc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<ReqBody> Clone for Route<ReqBody> {
|
impl<ReqBody> Clone for Route<ReqBody> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self(self.0.clone())
|
Self(self.0.clone())
|
||||||
|
@ -1091,6 +1057,41 @@ impl fmt::Debug for Node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum Fallback<B> {
|
||||||
|
Default(Route<B>),
|
||||||
|
Custom(Route<B>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B> Clone for Fallback<B> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Fallback::Default(inner) => Fallback::Default(inner.clone()),
|
||||||
|
Fallback::Custom(inner) => Fallback::Custom(inner.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B> fmt::Debug for Fallback<B> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
|
||||||
|
Self::Custom(inner) => f.debug_tuple("Custom").field(inner).finish(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B> Fallback<B> {
|
||||||
|
fn map<F, B2>(self, f: F) -> Fallback<B2>
|
||||||
|
where
|
||||||
|
F: FnOnce(Route<B>) -> Route<B2>,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Fallback::Default(inner) => Fallback::Default(f(inner)),
|
||||||
|
Fallback::Custom(inner) => Fallback::Custom(f(inner)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
38
src/routing/not_found.rs
Normal file
38
src/routing/not_found.rs
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
use crate::body::BoxBody;
|
||||||
|
use http::{Request, Response, StatusCode};
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
future::ready,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tower_service::Service;
|
||||||
|
|
||||||
|
/// A [`Service`] that responds with `405 Method not allowed` to all requests.
|
||||||
|
///
|
||||||
|
/// This is used as the bottom service in a method router. You shouldn't have to
|
||||||
|
/// use it manually.
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub(crate) struct NotFound;
|
||||||
|
|
||||||
|
impl<B> Service<Request<B>> for NotFound
|
||||||
|
where
|
||||||
|
B: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Response = Response<BoxBody>;
|
||||||
|
type Error = Infallible;
|
||||||
|
type Future = std::future::Ready<Result<Response<BoxBody>, Self::Error>>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||||
|
let res = Response::builder()
|
||||||
|
.status(StatusCode::NOT_FOUND)
|
||||||
|
.body(crate::body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
ready(Ok(res))
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ use hyper::Body;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::future::Ready;
|
use std::future::Ready;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
|
@ -552,35 +553,6 @@ async fn middleware_applies_to_routes_above() {
|
||||||
assert_eq!(res.status(), StatusCode::OK);
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn middleware_that_return_early() {
|
|
||||||
let app = Router::new()
|
|
||||||
.route("/", get(|| async {}))
|
|
||||||
.layer(RequireAuthorizationLayer::bearer("password"))
|
|
||||||
.route("/public", get(|| async {}));
|
|
||||||
|
|
||||||
let client = TestClient::new(app);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
client.get("/").send().await.status(),
|
|
||||||
StatusCode::UNAUTHORIZED
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
client
|
|
||||||
.get("/")
|
|
||||||
.header("authorization", "Bearer password")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.status(),
|
|
||||||
StatusCode::OK
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
client.get("/doesnt-exist").send().await.status(),
|
|
||||||
StatusCode::NOT_FOUND
|
|
||||||
);
|
|
||||||
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn with_trailing_slash() {
|
async fn with_trailing_slash() {
|
||||||
let app = Router::new().route("/foo", get(|| async {}));
|
let app = Router::new().route("/foo", get(|| async {}));
|
||||||
|
@ -672,6 +644,48 @@ async fn empty_route_nested() {
|
||||||
TestClient::new(app);
|
TestClient::new(app);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn middleware_still_run_for_unmatched_requests() {
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CountMiddleware<S>(S);
|
||||||
|
|
||||||
|
static COUNT: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
|
||||||
|
impl<R, S> Service<R> for CountMiddleware<S>
|
||||||
|
where
|
||||||
|
S: Service<R>,
|
||||||
|
{
|
||||||
|
type Response = S::Response;
|
||||||
|
type Error = S::Error;
|
||||||
|
type Future = S::Future;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.0.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: R) -> Self::Future {
|
||||||
|
COUNT.fetch_add(1, Ordering::SeqCst);
|
||||||
|
self.0.call(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", get(|| async {}))
|
||||||
|
.layer(tower::layer::layer_fn(CountMiddleware));
|
||||||
|
|
||||||
|
let client = TestClient::new(app);
|
||||||
|
|
||||||
|
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
|
||||||
|
|
||||||
|
client.get("/").send().await;
|
||||||
|
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
|
||||||
|
|
||||||
|
client.get("/not-found").send().await;
|
||||||
|
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(david): middleware still run for empty routers
|
||||||
|
|
||||||
pub(crate) fn assert_send<T: Send>() {}
|
pub(crate) fn assert_send<T: Send>() {}
|
||||||
pub(crate) fn assert_sync<T: Sync>() {}
|
pub(crate) fn assert_sync<T: Sync>() {}
|
||||||
pub(crate) fn assert_unpin<T: Unpin>() {}
|
pub(crate) fn assert_unpin<T: Unpin>() {}
|
||||||
|
|
Loading…
Reference in a new issue