From 513766b9e72898d32888772a9f7f5b3a2bc942fd Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 30 Jan 2022 21:41:34 +0100 Subject: [PATCH] Set `Allow` header when returning `405 Method Not Allowed` (#733) --- axum/CHANGELOG.md | 1 + axum/src/docs/method_routing/fallback.md | 9 + axum/src/routing/method_routing.rs | 205 ++++++++++++++++++++--- axum/src/routing/route.rs | 32 +++- 4 files changed, 221 insertions(+), 26 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index bc528f9b..9c6f307a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `PathRejection` - `MatchedPathRejection` - `WebSocketUpgradeRejection` +- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed` [#644]: https://github.com/tokio-rs/axum/pull/644 [#665]: https://github.com/tokio-rs/axum/pull/665 diff --git a/axum/src/docs/method_routing/fallback.md b/axum/src/docs/method_routing/fallback.md index c027578c..f00f7395 100644 --- a/axum/src/docs/method_routing/fallback.md +++ b/axum/src/docs/method_routing/fallback.md @@ -51,3 +51,12 @@ async fn fallback_two() -> impl IntoResponse {} # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; ``` + +## Setting the `Allow` header + +By default `MethodRouter` will set the `Allow` header when returning `405 Method +Not Allowed`. This is also done when the fallback is used unless the response +generated by the fallback already sets the `Allow` header. + +This means if you use `fallback` to accept additional methods make sure you set +the `Allow` header correctly. diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index e151de66..8daa1d68 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -7,6 +7,7 @@ use crate::{ routing::{Fallback, MethodFilter, Route}, BoxError, }; +use bytes::BytesMut; use std::{ convert::Infallible, fmt, @@ -386,7 +387,7 @@ where ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { - MethodRouter::new().fallback(svc) + MethodRouter::new().fallback(svc).skip_allow_header() } top_level_handler_fn!(delete, DELETE); @@ -469,7 +470,9 @@ where B: Send + 'static, T: 'static, { - MethodRouter::new().fallback_boxed_response_body(handler.into_service()) + MethodRouter::new() + .fallback_boxed_response_body(handler.into_service()) + .skip_allow_header() } /// A [`Service`] that accepts requests based on a [`MethodFilter`] and @@ -484,9 +487,20 @@ pub struct MethodRouter { put: Option>, trace: Option>, fallback: Fallback, + allow_header: AllowHeader, _request_body: PhantomData (B, E)>, } +#[derive(Clone)] +enum AllowHeader { + /// No `Allow` header value has been built-up yet. This is the default state + None, + /// Don't set an `Allow` header. This is used when `any` or `any_service` are called. + Skip, + /// The current value of the `Allow` header. + Bytes(BytesMut), +} + impl fmt::Debug for MethodRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MethodRouter") @@ -522,6 +536,7 @@ impl MethodRouter { post: None, put: None, trace: None, + allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), _request_body: PhantomData, } @@ -677,6 +692,7 @@ impl MethodRouter { put: self.put.map(layer_fn), trace: self.trace.map(layer_fn), fallback: self.fallback.map(layer_fn), + allow_header: self.allow_header, _request_body: PhantomData, } } @@ -710,6 +726,7 @@ impl MethodRouter { put: self.put.map(layer_fn), trace: self.trace.map(layer_fn), fallback: self.fallback, + allow_header: self.allow_header, _request_body: PhantomData, } } @@ -741,6 +758,7 @@ impl MethodRouter { put, trace, fallback, + allow_header, _request_body: _, } = self; @@ -754,6 +772,7 @@ impl MethodRouter { put: put_other, trace: trace_other, fallback: fallback_other, + allow_header: allow_header_other, _request_body: _, } = other; @@ -775,6 +794,18 @@ impl MethodRouter { } }; + let allow_header = match (allow_header, allow_header_other) { + (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip, + (AllowHeader::None, AllowHeader::None) => AllowHeader::None, + (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick), + (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick), + (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => { + a.extend_from_slice(b","); + a.extend_from_slice(&b); + AllowHeader::Bytes(a) + } + }; + Self { get, head, @@ -785,6 +816,7 @@ impl MethodRouter { put, trace, fallback, + allow_header, _request_body: PhantomData, } } @@ -821,31 +853,41 @@ impl MethodRouter { mut put, mut trace, fallback, + mut allow_header, _request_body: _, } = self; let svc = Some(Route::new(svc)); if filter.contains(MethodFilter::GET) { get = svc.clone(); + append_allow_header(&mut allow_header, "GET"); + append_allow_header(&mut allow_header, "HEAD"); } if filter.contains(MethodFilter::HEAD) { + append_allow_header(&mut allow_header, "HEAD"); head = svc.clone(); } if filter.contains(MethodFilter::DELETE) { + append_allow_header(&mut allow_header, "DELETE"); delete = svc.clone(); } if filter.contains(MethodFilter::OPTIONS) { + append_allow_header(&mut allow_header, "OPTIONS"); options = svc.clone(); } if filter.contains(MethodFilter::PATCH) { + append_allow_header(&mut allow_header, "PATCH"); patch = svc.clone(); } if filter.contains(MethodFilter::POST) { + append_allow_header(&mut allow_header, "POST"); post = svc.clone(); } if filter.contains(MethodFilter::PUT) { + append_allow_header(&mut allow_header, "PUT"); put = svc.clone(); } if filter.contains(MethodFilter::TRACE) { + append_allow_header(&mut allow_header, "TRACE"); trace = svc; } Self { @@ -858,9 +900,35 @@ impl MethodRouter { put, trace, fallback, + allow_header, _request_body: PhantomData, } } + + fn skip_allow_header(mut self) -> Self { + self.allow_header = AllowHeader::Skip; + self + } +} + +fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { + match allow_header { + AllowHeader::None => { + *allow_header = AllowHeader::Bytes(BytesMut::from(method)); + } + AllowHeader::Skip => {} + AllowHeader::Bytes(allow_header) => { + if let Ok(s) = std::str::from_utf8(allow_header) { + if !s.contains(method) { + allow_header.extend_from_slice(b","); + allow_header.extend_from_slice(method.as_bytes()); + } + } else { + #[cfg(debug_assertions)] + panic!("`allow_header` contained invalid uft-8. This should never happen") + } + } + } } impl Clone for MethodRouter { @@ -875,6 +943,7 @@ impl Clone for MethodRouter { put: self.put.clone(), trace: self.trace.clone(), fallback: self.fallback.clone(), + allow_header: self.allow_header.clone(), _request_body: PhantomData, } } @@ -931,6 +1000,7 @@ impl Service> for MethodRouter { put, trace, fallback, + allow_header, _request_body: _, } = self; @@ -944,13 +1014,18 @@ impl Service> for MethodRouter { call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - match fallback { - Fallback::Default(fallback) => { - RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) - } - Fallback::Custom(fallback) => { - RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) - } + let future = + match fallback { + Fallback::Default(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req)) + .strip_body(method == Method::HEAD), + Fallback::Custom(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req)) + .strip_body(method == Method::HEAD), + }; + + match allow_header { + AllowHeader::None => future.allow_header(Bytes::new()), + AllowHeader::Skip => future, + AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()), } } } @@ -959,6 +1034,8 @@ impl Service> for MethodRouter { mod tests { use super::*; use crate::{body::Body, error_handling::HandleErrorLayer}; + use axum_core::response::IntoResponse; + use http::{header::ALLOW, HeaderMap}; use std::time::Duration; use tower::{timeout::TimeoutLayer, Service, ServiceExt}; use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir}; @@ -966,7 +1043,7 @@ mod tests { #[tokio::test] async fn method_not_allowed_by_default() { let mut svc = MethodRouter::new(); - let (status, body) = call(Method::GET, &mut svc).await; + let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert!(body.is_empty()); } @@ -974,7 +1051,7 @@ mod tests { #[tokio::test] async fn get_handler() { let mut svc = MethodRouter::new().get(ok); - let (status, body) = call(Method::GET, &mut svc).await; + let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(body, "ok"); } @@ -982,7 +1059,7 @@ mod tests { #[tokio::test] async fn get_accepts_head() { let mut svc = MethodRouter::new().get(ok); - let (status, body) = call(Method::HEAD, &mut svc).await; + let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::OK); assert!(body.is_empty()); } @@ -990,7 +1067,7 @@ mod tests { #[tokio::test] async fn head_takes_precedence_over_get() { let mut svc = MethodRouter::new().head(created).get(ok); - let (status, body) = call(Method::HEAD, &mut svc).await; + let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::CREATED); assert!(body.is_empty()); } @@ -999,10 +1076,10 @@ mod tests { async fn merge() { let mut svc = get(ok).merge(post(ok)); - let (status, _) = call(Method::GET, &mut svc).await; + let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); - let (status, _) = call(Method::POST, &mut svc).await; + let (status, _, _) = call(Method::POST, &mut svc).await; assert_eq!(status, StatusCode::OK); } @@ -1013,11 +1090,11 @@ mod tests { .layer(RequireAuthorizationLayer::bearer("password")); // method with route - let (status, _) = call(Method::GET, &mut svc).await; + let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::UNAUTHORIZED); // method without route - let (status, _) = call(Method::DELETE, &mut svc).await; + let (status, _, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::UNAUTHORIZED); } @@ -1028,11 +1105,11 @@ mod tests { .route_layer(RequireAuthorizationLayer::bearer("password")); // method with route - let (status, _) = call(Method::GET, &mut svc).await; + let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::UNAUTHORIZED); // method without route - let (status, _) = call(Method::DELETE, &mut svc).await; + let (status, _, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); } @@ -1062,7 +1139,93 @@ mod tests { crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service()); } - async fn call(method: Method, svc: &mut S) -> (StatusCode, String) + #[tokio::test] + async fn sets_allow_header() { + let mut svc = MethodRouter::new().put(ok).patch(ok); + let (status, headers, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "PUT,PATCH"); + } + + #[tokio::test] + async fn sets_allow_header_get_head() { + let mut svc = MethodRouter::new().get(ok).head(ok); + let (status, headers, _) = call(Method::PUT, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "GET,HEAD"); + } + + #[tokio::test] + async fn empty_allow_header_by_default() { + let mut svc = MethodRouter::new(); + let (status, headers, _) = call(Method::PATCH, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], ""); + } + + #[tokio::test] + async fn allow_header_when_merging() { + let a = put(ok).patch(ok); + let b = get(ok).head(ok); + let mut svc = a.merge(b); + + let (status, headers, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD"); + } + + #[tokio::test] + async fn allow_header_any() { + let mut svc = any(ok); + + let (status, headers, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert!(!headers.contains_key(ALLOW)); + } + + #[tokio::test] + async fn allow_header_with_fallback() { + let mut svc = MethodRouter::new().get(ok).fallback( + (|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(), + ); + + let (status, headers, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "GET,HEAD"); + } + + #[tokio::test] + async fn allow_header_with_fallback_that_sets_allow() { + async fn fallback(method: Method) -> Response { + if method == Method::POST { + "OK".into_response() + } else { + let headers = crate::response::Headers([(ALLOW, "GET,POST")]); + ( + StatusCode::METHOD_NOT_ALLOWED, + headers, + "Method not allowed", + ) + .into_response() + } + } + + let mut svc = MethodRouter::new() + .get(ok) + .fallback(fallback.into_service()); + + let (status, _, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + + let (status, _, _) = call(Method::POST, &mut svc).await; + assert_eq!(status, StatusCode::OK); + + let (status, headers, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "GET,POST"); + } + + async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where S: Service, Response = Response, Error = Infallible>, { @@ -1074,7 +1237,7 @@ mod tests { let response = svc.ready().await.unwrap().call(request).await.unwrap(); let (parts, body) = response.into_parts(); let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap(); - (parts.status, body) + (parts.status, parts.headers, body) } async fn ok() -> (StatusCode, &'static str) { diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index b5a9a74c..f3b4a168 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -2,7 +2,8 @@ use crate::{ body::{boxed, Body, Empty}, response::Response, }; -use http::Request; +use bytes::Bytes; +use http::{header, HeaderValue, Request}; use pin_project_lite::pin_project; use std::{ convert::Infallible, @@ -70,6 +71,7 @@ pin_project! { Request, >, strip_body: bool, + allow_header: Option, } } @@ -80,6 +82,7 @@ impl RouteFuture { RouteFuture { future, strip_body: false, + allow_header: None, } } @@ -87,22 +90,41 @@ impl RouteFuture { self.strip_body = strip_body; self } + + pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { + self.allow_header = Some(allow_header); + self + } } impl Future for RouteFuture { type Output = Result; #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let strip_body = self.strip_body; + let allow_header = self.allow_header.take(); match self.project().future.poll(cx) { Poll::Ready(Ok(res)) => { - if strip_body { - Poll::Ready(Ok(res.map(|_| boxed(Empty::new())))) + let mut res = if strip_body { + res.map(|_| boxed(Empty::new())) } else { - Poll::Ready(Ok(res)) + res + }; + + match allow_header { + Some(allow_header) if !res.headers().contains_key(header::ALLOW) => { + res.headers_mut().insert( + header::ALLOW, + HeaderValue::from_maybe_shared(allow_header) + .expect("invalid `Allow` header"), + ); + } + _ => {} } + + Poll::Ready(Ok(res)) } Poll::Ready(Err(err)) => Poll::Ready(Err(err)), Poll::Pending => Poll::Pending,