mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Add Router::route_layer
(#474)
This addresses something thats been bothering me for some time: Most middleware need to run regardless if the request matches a route or not. For example you don't wanna skip logging for unmatched requests. However middleware such as authorization only make sense to run for matching requests. This previously wasn't possible to express and you'd have to manually apply the middleware to each handler. Consider this: ```rust Router::new() .route("/foo", get(|| async {})) .layer(RequireAuthorizationLayer::bearer("password")); ``` Calling `GET /foo` with an invalid token would receive `401 Unauthorized` as expected however calling some unknown route like `GET /not-found` would also return `401 Unauthorized`. I think this is unexpected and have seen a few users ask questions about it. It happened because the 404 you'd otherwise see is generated by a fallback service stored on `Router`. When adding a layer to the router the layer would also be applied to the fallback, which in the case of auth means the fallback would never be called for unauthorized requests. I think what axum does today is the right default however I still think we should support this somehow. Especially since [`extractor_middleware`](https://docs.rs/axum/0.3.1/axum/extract/fn.extractor_middleware.html) is mainly useful for auth but it doesn't work great today due to this gotcha. This PR proposes adding `Router::layer_on_matching_route` which only applies layers to routes, not the fallback, which fixes the issue. I'm not a big fan of the name `layer_on_matching_route`, would like something shorter, but I think it communicates the purpose decently. The generics are a bit different since the request body used on the routes and the fallback must match, so layers that changes the request body type are not compatible with `layer_on_matching_route`. Such middleware are very rare so that should be fine.
This commit is contained in:
parent
5d8ebce211
commit
7eb2c40b24
5 changed files with 126 additions and 23 deletions
|
@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
# Unreleased
|
||||
|
||||
- None.
|
||||
- **added:** Add `Router::route_layer` for applying middleware that
|
||||
will only run on requests that match a route ([#474])
|
||||
|
||||
[#474]: https://github.com/tokio-rs/axum/pull/474
|
||||
|
||||
# 0.3.1 (06. November, 2021)
|
||||
|
||||
|
|
28
axum/src/docs/routing/route_layer.md
Normal file
28
axum/src/docs/routing/route_layer.md
Normal file
|
@ -0,0 +1,28 @@
|
|||
Apply a [`tower::Layer`] to the router that will only run if the request matches
|
||||
a route.
|
||||
|
||||
This works similarly to [`Router::layer`] except the middleware will only run if
|
||||
the request matches a route. This is useful for middleware that return early
|
||||
(such as authorization) which might otherwise convert a `404 Not Found` into a
|
||||
`401 Unauthorized`.
|
||||
|
||||
# Example
|
||||
|
||||
```rust
|
||||
use axum::{
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use tower_http::auth::RequireAuthorizationLayer;
|
||||
|
||||
let app = Router::new()
|
||||
.route("/foo", get(|| async {}))
|
||||
.route_layer(RequireAuthorizationLayer::bearer("password"));
|
||||
|
||||
// `GET /foo` with a valid token will receive `200 OK`
|
||||
// `GET /foo` with a invalid token will receive `401 Unauthorized`
|
||||
// `GET /not-found` with a invalid token will receive `404 Not Found`
|
||||
# async {
|
||||
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
# };
|
||||
```
|
|
@ -37,38 +37,52 @@ use tower_service::Service;
|
|||
///
|
||||
/// ```rust
|
||||
/// use axum::{
|
||||
/// Router,
|
||||
/// async_trait,
|
||||
/// extract::{extractor_middleware, FromRequest, RequestParts},
|
||||
/// http::StatusCode,
|
||||
/// routing::{get, post},
|
||||
/// Router,
|
||||
/// };
|
||||
/// use std::convert::Infallible;
|
||||
/// use http::StatusCode;
|
||||
/// use async_trait::async_trait;
|
||||
///
|
||||
/// struct MyExtractor;
|
||||
/// // An extractor that performs authorization.
|
||||
/// struct RequireAuth;
|
||||
///
|
||||
/// #[async_trait]
|
||||
/// impl<B> FromRequest<B> for MyExtractor
|
||||
/// impl<B> FromRequest<B> for RequireAuth
|
||||
/// where
|
||||
/// B: Send,
|
||||
/// {
|
||||
/// type Rejection = Infallible;
|
||||
/// type Rejection = StatusCode;
|
||||
///
|
||||
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
/// # Ok(Self)
|
||||
/// // ...
|
||||
/// let auth_header = req
|
||||
/// .headers()
|
||||
/// .and_then(|headers| headers.get(http::header::AUTHORIZATION))
|
||||
/// .and_then(|value| value.to_str().ok());
|
||||
///
|
||||
/// if let Some(value) = auth_header {
|
||||
/// if value == "secret" {
|
||||
/// return Ok(Self);
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// Err(StatusCode::UNAUTHORIZED)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// async fn handler() {}
|
||||
/// async fn handler() {
|
||||
/// // If we get here the request has been authorized
|
||||
/// }
|
||||
///
|
||||
/// async fn other_handler() {}
|
||||
/// async fn other_handler() {
|
||||
/// // If we get here the request has been authorized
|
||||
/// }
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/", get(handler))
|
||||
/// .route("/foo", post(other_handler))
|
||||
/// // The extractor will run before all routes
|
||||
/// .layer(extractor_middleware::<MyExtractor>());
|
||||
/// .route_layer(extractor_middleware::<RequireAuth>());
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
|
|
|
@ -223,19 +223,16 @@ where
|
|||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/layer.md")]
|
||||
pub fn layer<L, LayeredReqBody, LayeredResBody>(self, layer: L) -> Router<LayeredReqBody>
|
||||
pub fn layer<L, NewReqBody, NewResBody>(self, layer: L) -> Router<NewReqBody>
|
||||
where
|
||||
L: Layer<Route<B>>,
|
||||
L::Service: Service<
|
||||
Request<LayeredReqBody>,
|
||||
Response = Response<LayeredResBody>,
|
||||
Error = Infallible,
|
||||
> + Clone
|
||||
L::Service: Service<Request<NewReqBody>, Response = Response<NewResBody>, Error = Infallible>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
<L::Service as Service<Request<LayeredReqBody>>>::Future: Send + 'static,
|
||||
LayeredResBody: http_body::Body<Data = Bytes> + Send + 'static,
|
||||
LayeredResBody::Error: Into<BoxError>,
|
||||
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
|
||||
NewResBody: http_body::Body<Data = Bytes> + Send + 'static,
|
||||
NewResBody::Error: Into<BoxError>,
|
||||
{
|
||||
let layer = ServiceBuilder::new()
|
||||
.layer_fn(Route::new)
|
||||
|
@ -249,7 +246,7 @@ where
|
|||
let route = Layer::layer(&layer, route);
|
||||
(id, route)
|
||||
})
|
||||
.collect::<HashMap<RouteId, Route<LayeredReqBody>>>();
|
||||
.collect();
|
||||
|
||||
let fallback = self.fallback.map(|svc| Layer::layer(&layer, svc));
|
||||
|
||||
|
@ -260,6 +257,39 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/route_layer.md")]
|
||||
pub fn route_layer<L, NewResBody>(self, layer: L) -> Self
|
||||
where
|
||||
L: Layer<Route<B>>,
|
||||
L::Service: Service<Request<B>, Response = Response<NewResBody>, Error = Infallible>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
<L::Service as Service<Request<B>>>::Future: Send + 'static,
|
||||
NewResBody: http_body::Body<Data = Bytes> + Send + 'static,
|
||||
NewResBody::Error: Into<BoxError>,
|
||||
{
|
||||
let layer = ServiceBuilder::new()
|
||||
.layer_fn(Route::new)
|
||||
.layer(MapResponseBodyLayer::new(box_body))
|
||||
.layer(layer);
|
||||
|
||||
let routes = self
|
||||
.routes
|
||||
.into_iter()
|
||||
.map(|(id, route)| {
|
||||
let route = Layer::layer(&layer, route);
|
||||
(id, route)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Router {
|
||||
routes,
|
||||
node: self.node,
|
||||
fallback: self.fallback,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||
pub fn fallback<T>(mut self, svc: T) -> Self
|
||||
where
|
||||
|
|
|
@ -478,3 +478,31 @@ async fn middleware_still_run_for_unmatched_requests() {
|
|||
async fn routing_to_router_panics() {
|
||||
TestClient::new(Router::new().route("/", Router::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_layer() {
|
||||
let app = Router::new()
|
||||
.route("/foo", get(|| async {}))
|
||||
.route_layer(RequireAuthorizationLayer::bearer("password"));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client
|
||||
.get("/foo")
|
||||
.header("authorization", "Bearer password")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client.get("/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
|
||||
let res = client.get("/not-found").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
// it would be nice if this would return `405 Method Not Allowed`
|
||||
// but that requires knowing more about which method route we're calling, which we
|
||||
// don't know currently since its just a generic `Service`
|
||||
let res = client.post("/foo").send().await;
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue