From 7eb2c40b24414c4fc0cd07edd0d980dffaba2126 Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Mon, 8 Nov 2021 18:48:02 +0100 Subject: [PATCH] Add `Router::route_layer` (#474) This addresses something thats been bothering me for some time: Most middleware need to run regardless if the request matches a route or not. For example you don't wanna skip logging for unmatched requests. However middleware such as authorization only make sense to run for matching requests. This previously wasn't possible to express and you'd have to manually apply the middleware to each handler. Consider this: ```rust Router::new() .route("/foo", get(|| async {})) .layer(RequireAuthorizationLayer::bearer("password")); ``` Calling `GET /foo` with an invalid token would receive `401 Unauthorized` as expected however calling some unknown route like `GET /not-found` would also return `401 Unauthorized`. I think this is unexpected and have seen a few users ask questions about it. It happened because the 404 you'd otherwise see is generated by a fallback service stored on `Router`. When adding a layer to the router the layer would also be applied to the fallback, which in the case of auth means the fallback would never be called for unauthorized requests. I think what axum does today is the right default however I still think we should support this somehow. Especially since [`extractor_middleware`](https://docs.rs/axum/0.3.1/axum/extract/fn.extractor_middleware.html) is mainly useful for auth but it doesn't work great today due to this gotcha. This PR proposes adding `Router::layer_on_matching_route` which only applies layers to routes, not the fallback, which fixes the issue. I'm not a big fan of the name `layer_on_matching_route`, would like something shorter, but I think it communicates the purpose decently. The generics are a bit different since the request body used on the routes and the fallback must match, so layers that changes the request body type are not compatible with `layer_on_matching_route`. Such middleware are very rare so that should be fine. --- axum/CHANGELOG.md | 5 ++- axum/src/docs/routing/route_layer.md | 28 +++++++++++++ axum/src/extract/extractor_middleware.rs | 38 ++++++++++++------ axum/src/routing/mod.rs | 50 +++++++++++++++++++----- axum/src/routing/tests/mod.rs | 28 +++++++++++++ 5 files changed, 126 insertions(+), 23 deletions(-) create mode 100644 axum/src/docs/routing/route_layer.md diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 23727d30..be2d7693 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `Router::route_layer` for applying middleware that + will only run on requests that match a route ([#474]) + +[#474]: https://github.com/tokio-rs/axum/pull/474 # 0.3.1 (06. November, 2021) diff --git a/axum/src/docs/routing/route_layer.md b/axum/src/docs/routing/route_layer.md new file mode 100644 index 00000000..58a42cd8 --- /dev/null +++ b/axum/src/docs/routing/route_layer.md @@ -0,0 +1,28 @@ +Apply a [`tower::Layer`] to the router that will only run if the request matches +a route. + +This works similarly to [`Router::layer`] except the middleware will only run if +the request matches a route. This is useful for middleware that return early +(such as authorization) which might otherwise convert a `404 Not Found` into a +`401 Unauthorized`. + +# Example + +```rust +use axum::{ + routing::get, + Router, +}; +use tower_http::auth::RequireAuthorizationLayer; + +let app = Router::new() + .route("/foo", get(|| async {})) + .route_layer(RequireAuthorizationLayer::bearer("password")); + +// `GET /foo` with a valid token will receive `200 OK` +// `GET /foo` with a invalid token will receive `401 Unauthorized` +// `GET /not-found` with a invalid token will receive `404 Not Found` +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` diff --git a/axum/src/extract/extractor_middleware.rs b/axum/src/extract/extractor_middleware.rs index e7aeee79..85015268 100644 --- a/axum/src/extract/extractor_middleware.rs +++ b/axum/src/extract/extractor_middleware.rs @@ -37,38 +37,52 @@ use tower_service::Service; /// /// ```rust /// use axum::{ -/// Router, -/// async_trait, /// extract::{extractor_middleware, FromRequest, RequestParts}, -/// http::StatusCode, /// routing::{get, post}, +/// Router, /// }; -/// use std::convert::Infallible; +/// use http::StatusCode; +/// use async_trait::async_trait; /// -/// struct MyExtractor; +/// // An extractor that performs authorization. +/// struct RequireAuth; /// /// #[async_trait] -/// impl<B> FromRequest<B> for MyExtractor +/// impl<B> FromRequest<B> for RequireAuth /// where /// B: Send, /// { -/// type Rejection = Infallible; +/// type Rejection = StatusCode; /// /// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { -/// # Ok(Self) -/// // ... +/// let auth_header = req +/// .headers() +/// .and_then(|headers| headers.get(http::header::AUTHORIZATION)) +/// .and_then(|value| value.to_str().ok()); +/// +/// if let Some(value) = auth_header { +/// if value == "secret" { +/// return Ok(Self); +/// } +/// } +/// +/// Err(StatusCode::UNAUTHORIZED) /// } /// } /// -/// async fn handler() {} +/// async fn handler() { +/// // If we get here the request has been authorized +/// } /// -/// async fn other_handler() {} +/// async fn other_handler() { +/// // If we get here the request has been authorized +/// } /// /// let app = Router::new() /// .route("/", get(handler)) /// .route("/foo", post(other_handler)) /// // The extractor will run before all routes -/// .layer(extractor_middleware::<MyExtractor>()); +/// .route_layer(extractor_middleware::<RequireAuth>()); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 68321d4d..b9d8ba06 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -223,19 +223,16 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer<L, LayeredReqBody, LayeredResBody>(self, layer: L) -> Router<LayeredReqBody> + pub fn layer<L, NewReqBody, NewResBody>(self, layer: L) -> Router<NewReqBody> where L: Layer<Route<B>>, - L::Service: Service< - Request<LayeredReqBody>, - Response = Response<LayeredResBody>, - Error = Infallible, - > + Clone + L::Service: Service<Request<NewReqBody>, Response = Response<NewResBody>, Error = Infallible> + + Clone + Send + 'static, - <L::Service as Service<Request<LayeredReqBody>>>::Future: Send + 'static, - LayeredResBody: http_body::Body<Data = Bytes> + Send + 'static, - LayeredResBody::Error: Into<BoxError>, + <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, + NewResBody: http_body::Body<Data = Bytes> + Send + 'static, + NewResBody::Error: Into<BoxError>, { let layer = ServiceBuilder::new() .layer_fn(Route::new) @@ -249,7 +246,7 @@ where let route = Layer::layer(&layer, route); (id, route) }) - .collect::<HashMap<RouteId, Route<LayeredReqBody>>>(); + .collect(); let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc)); @@ -260,6 +257,39 @@ where } } + #[doc = include_str!("../docs/routing/route_layer.md")] + pub fn route_layer<L, NewResBody>(self, layer: L) -> Self + where + L: Layer<Route<B>>, + L::Service: Service<Request<B>, Response = Response<NewResBody>, Error = Infallible> + + Clone + + Send + + 'static, + <L::Service as Service<Request<B>>>::Future: Send + 'static, + NewResBody: http_body::Body<Data = Bytes> + Send + 'static, + NewResBody::Error: Into<BoxError>, + { + let layer = ServiceBuilder::new() + .layer_fn(Route::new) + .layer(MapResponseBodyLayer::new(box_body)) + .layer(layer); + + let routes = self + .routes + .into_iter() + .map(|(id, route)| { + let route = Layer::layer(&layer, route); + (id, route) + }) + .collect(); + + Router { + routes, + node: self.node, + fallback: self.fallback, + } + } + #[doc = include_str!("../docs/routing/fallback.md")] pub fn fallback<T>(mut self, svc: T) -> Self where diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index c674bd26..84bb4565 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -478,3 +478,31 @@ async fn middleware_still_run_for_unmatched_requests() { async fn routing_to_router_panics() { TestClient::new(Router::new().route("/", Router::new())); } + +#[tokio::test] +async fn route_layer() { + let app = Router::new() + .route("/foo", get(|| async {})) + .route_layer(RequireAuthorizationLayer::bearer("password")); + + let client = TestClient::new(app); + + let res = client + .get("/foo") + .header("authorization", "Bearer password") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let res = client.get("/not-found").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + // it would be nice if this would return `405 Method Not Allowed` + // but that requires knowing more about which method route we're calling, which we + // don't know currently since its just a generic `Service` + let res = client.post("/foo").send().await; + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +}