From baf7cabfe1a50dfaf5d1e47501134da3bff3e9e2 Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Mon, 25 Oct 2021 21:46:49 +0200
Subject: [PATCH] Add test for middleware that return early (#409)

* Add test for middleware that return early

Turns out #408 fixed #380.

* changelog
---
 CHANGELOG.md       |  3 +++
 src/tests/merge.rs | 30 ++++++++++++++++++++++++++++++
 src/tests/mod.rs   | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38c333cc..beeb15da 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -123,6 +123,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   Use `Router::fallback` for adding fallback routes ([#408])
 - **added:** `Router::fallback` for adding handlers for request that didn't
   match any routes ([#408])
+- **fixed:** Middleware that return early (such as `tower_http::auth::RequireAuthorization`)
+  now no longer catch requests that would otherwise be 404s. They also work
+  correctly with `Router::merge` (previously called `or`) ([#408])
 
 [#339]: https://github.com/tokio-rs/axum/pull/339
 [#286]: https://github.com/tokio-rs/axum/pull/286
diff --git a/src/tests/merge.rs b/src/tests/merge.rs
index a91dcb57..f4ba00cd 100644
--- a/src/tests/merge.rs
+++ b/src/tests/merge.rs
@@ -373,3 +373,33 @@ async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() {
         })
     );
 }
+
+#[tokio::test]
+async fn middleware_that_return_early() {
+    let private = Router::new()
+        .route("/", get(|| async {}))
+        .layer(RequireAuthorizationLayer::bearer("password"));
+
+    let public = Router::new().route("/public", get(|| async {}));
+
+    let client = TestClient::new(private.merge(public));
+
+    assert_eq!(
+        client.get("/").send().await.status(),
+        StatusCode::UNAUTHORIZED
+    );
+    assert_eq!(
+        client
+            .get("/")
+            .header("authorization", "Bearer password")
+            .send()
+            .await
+            .status(),
+        StatusCode::OK
+    );
+    assert_eq!(
+        client.get("/doesnt-exist").send().await.status(),
+        StatusCode::NOT_FOUND
+    );
+    assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
+}
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
index cf9f3bff..73b36692 100644
--- a/src/tests/mod.rs
+++ b/src/tests/mod.rs
@@ -25,8 +25,8 @@ use std::{
     task::{Context, Poll},
     time::Duration,
 };
-use tower::timeout::TimeoutLayer;
-use tower::{service_fn, ServiceBuilder};
+use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder};
+use tower_http::auth::RequireAuthorizationLayer;
 use tower_service::Service;
 
 pub(crate) use helpers::*;
@@ -550,6 +550,35 @@ async fn middleware_applies_to_routes_above() {
     assert_eq!(res.status(), StatusCode::OK);
 }
 
+#[tokio::test]
+async fn middleware_that_return_early() {
+    let app = Router::new()
+        .route("/", get(|| async {}))
+        .layer(RequireAuthorizationLayer::bearer("password"))
+        .route("/public", get(|| async {}));
+
+    let client = TestClient::new(app);
+
+    assert_eq!(
+        client.get("/").send().await.status(),
+        StatusCode::UNAUTHORIZED
+    );
+    assert_eq!(
+        client
+            .get("/")
+            .header("authorization", "Bearer password")
+            .send()
+            .await
+            .status(),
+        StatusCode::OK
+    );
+    assert_eq!(
+        client.get("/doesnt-exist").send().await.status(),
+        StatusCode::NOT_FOUND
+    );
+    assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
+}
+
 pub(crate) fn assert_send<T: Send>() {}
 pub(crate) fn assert_sync<T: Sync>() {}
 pub(crate) fn assert_unpin<T: Unpin>() {}