Set Allow header when returning 405 Method Not Allowed (#733)

This commit is contained in:
David Pedersen 2022-01-30 21:41:34 +01:00 committed by GitHub
parent a04fc42d75
commit 513766b9e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 221 additions and 26 deletions

View file

@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PathRejection` - `PathRejection`
- `MatchedPathRejection` - `MatchedPathRejection`
- `WebSocketUpgradeRejection` - `WebSocketUpgradeRejection`
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed`
[#644]: https://github.com/tokio-rs/axum/pull/644 [#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665 [#665]: https://github.com/tokio-rs/axum/pull/665

View file

@ -51,3 +51,12 @@ async fn fallback_two() -> impl IntoResponse {}
# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# }; # };
``` ```
## Setting the `Allow` header
By default `MethodRouter` will set the `Allow` header when returning `405 Method
Not Allowed`. This is also done when the fallback is used unless the response
generated by the fallback already sets the `Allow` header.
This means if you use `fallback` to accept additional methods make sure you set
the `Allow` header correctly.

View file

@ -7,6 +7,7 @@ use crate::{
routing::{Fallback, MethodFilter, Route}, routing::{Fallback, MethodFilter, Route},
BoxError, BoxError,
}; };
use bytes::BytesMut;
use std::{ use std::{
convert::Infallible, convert::Infallible,
fmt, fmt,
@ -386,7 +387,7 @@ where
ResBody: HttpBody<Data = Bytes> + Send + 'static, ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>, ResBody::Error: Into<BoxError>,
{ {
MethodRouter::new().fallback(svc) MethodRouter::new().fallback(svc).skip_allow_header()
} }
top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(delete, DELETE);
@ -469,7 +470,9 @@ where
B: Send + 'static, B: Send + 'static,
T: 'static, T: 'static,
{ {
MethodRouter::new().fallback_boxed_response_body(handler.into_service()) MethodRouter::new()
.fallback_boxed_response_body(handler.into_service())
.skip_allow_header()
} }
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// A [`Service`] that accepts requests based on a [`MethodFilter`] and
@ -484,9 +487,20 @@ pub struct MethodRouter<B = Body, E = Infallible> {
put: Option<Route<B, E>>, put: Option<Route<B, E>>,
trace: Option<Route<B, E>>, trace: Option<Route<B, E>>,
fallback: Fallback<B, E>, fallback: Fallback<B, E>,
allow_header: AllowHeader,
_request_body: PhantomData<fn() -> (B, E)>, _request_body: PhantomData<fn() -> (B, E)>,
} }
#[derive(Clone)]
enum AllowHeader {
/// No `Allow` header value has been built-up yet. This is the default state
None,
/// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
Skip,
/// The current value of the `Allow` header.
Bytes(BytesMut),
}
impl<B, E> fmt::Debug for MethodRouter<B, E> { impl<B, E> fmt::Debug for MethodRouter<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MethodRouter") f.debug_struct("MethodRouter")
@ -522,6 +536,7 @@ impl<B, E> MethodRouter<B, E> {
post: None, post: None,
put: None, put: None,
trace: None, trace: None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback), fallback: Fallback::Default(fallback),
_request_body: PhantomData, _request_body: PhantomData,
} }
@ -677,6 +692,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: self.put.map(layer_fn), put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn), trace: self.trace.map(layer_fn),
fallback: self.fallback.map(layer_fn), fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
_request_body: PhantomData, _request_body: PhantomData,
} }
} }
@ -710,6 +726,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: self.put.map(layer_fn), put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn), trace: self.trace.map(layer_fn),
fallback: self.fallback, fallback: self.fallback,
allow_header: self.allow_header,
_request_body: PhantomData, _request_body: PhantomData,
} }
} }
@ -741,6 +758,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put, put,
trace, trace,
fallback, fallback,
allow_header,
_request_body: _, _request_body: _,
} = self; } = self;
@ -754,6 +772,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: put_other, put: put_other,
trace: trace_other, trace: trace_other,
fallback: fallback_other, fallback: fallback_other,
allow_header: allow_header_other,
_request_body: _, _request_body: _,
} = other; } = other;
@ -775,6 +794,18 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
} }
}; };
let allow_header = match (allow_header, allow_header_other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
};
Self { Self {
get, get,
head, head,
@ -785,6 +816,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put, put,
trace, trace,
fallback, fallback,
allow_header,
_request_body: PhantomData, _request_body: PhantomData,
} }
} }
@ -821,31 +853,41 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
mut put, mut put,
mut trace, mut trace,
fallback, fallback,
mut allow_header,
_request_body: _, _request_body: _,
} = self; } = self;
let svc = Some(Route::new(svc)); let svc = Some(Route::new(svc));
if filter.contains(MethodFilter::GET) { if filter.contains(MethodFilter::GET) {
get = svc.clone(); get = svc.clone();
append_allow_header(&mut allow_header, "GET");
append_allow_header(&mut allow_header, "HEAD");
} }
if filter.contains(MethodFilter::HEAD) { if filter.contains(MethodFilter::HEAD) {
append_allow_header(&mut allow_header, "HEAD");
head = svc.clone(); head = svc.clone();
} }
if filter.contains(MethodFilter::DELETE) { if filter.contains(MethodFilter::DELETE) {
append_allow_header(&mut allow_header, "DELETE");
delete = svc.clone(); delete = svc.clone();
} }
if filter.contains(MethodFilter::OPTIONS) { if filter.contains(MethodFilter::OPTIONS) {
append_allow_header(&mut allow_header, "OPTIONS");
options = svc.clone(); options = svc.clone();
} }
if filter.contains(MethodFilter::PATCH) { if filter.contains(MethodFilter::PATCH) {
append_allow_header(&mut allow_header, "PATCH");
patch = svc.clone(); patch = svc.clone();
} }
if filter.contains(MethodFilter::POST) { if filter.contains(MethodFilter::POST) {
append_allow_header(&mut allow_header, "POST");
post = svc.clone(); post = svc.clone();
} }
if filter.contains(MethodFilter::PUT) { if filter.contains(MethodFilter::PUT) {
append_allow_header(&mut allow_header, "PUT");
put = svc.clone(); put = svc.clone();
} }
if filter.contains(MethodFilter::TRACE) { if filter.contains(MethodFilter::TRACE) {
append_allow_header(&mut allow_header, "TRACE");
trace = svc; trace = svc;
} }
Self { Self {
@ -858,9 +900,35 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put, put,
trace, trace,
fallback, fallback,
allow_header,
_request_body: PhantomData, _request_body: PhantomData,
} }
} }
fn skip_allow_header(mut self) -> Self {
self.allow_header = AllowHeader::Skip;
self
}
}
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
match allow_header {
AllowHeader::None => {
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
}
AllowHeader::Skip => {}
AllowHeader::Bytes(allow_header) => {
if let Ok(s) = std::str::from_utf8(allow_header) {
if !s.contains(method) {
allow_header.extend_from_slice(b",");
allow_header.extend_from_slice(method.as_bytes());
}
} else {
#[cfg(debug_assertions)]
panic!("`allow_header` contained invalid uft-8. This should never happen")
}
}
}
} }
impl<B, E> Clone for MethodRouter<B, E> { impl<B, E> Clone for MethodRouter<B, E> {
@ -875,6 +943,7 @@ impl<B, E> Clone for MethodRouter<B, E> {
put: self.put.clone(), put: self.put.clone(),
trace: self.trace.clone(), trace: self.trace.clone(),
fallback: self.fallback.clone(), fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
_request_body: PhantomData, _request_body: PhantomData,
} }
} }
@ -931,6 +1000,7 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
put, put,
trace, trace,
fallback, fallback,
allow_header,
_request_body: _, _request_body: _,
} = self; } = self;
@ -944,13 +1014,18 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
call!(req, method, DELETE, delete); call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace); call!(req, method, TRACE, trace);
match fallback { let future =
Fallback::Default(fallback) => { match fallback {
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) Fallback::Default(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
} .strip_body(method == Method::HEAD),
Fallback::Custom(fallback) => { Fallback::Custom(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD) .strip_body(method == Method::HEAD),
} };
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
AllowHeader::Skip => future,
AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
} }
} }
} }
@ -959,6 +1034,8 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
mod tests { mod tests {
use super::*; use super::*;
use crate::{body::Body, error_handling::HandleErrorLayer}; use crate::{body::Body, error_handling::HandleErrorLayer};
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use std::time::Duration; use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceExt}; use tower::{timeout::TimeoutLayer, Service, ServiceExt};
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir}; use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
@ -966,7 +1043,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn method_not_allowed_by_default() { async fn method_not_allowed_by_default() {
let mut svc = MethodRouter::new(); let mut svc = MethodRouter::new();
let (status, body) = call(Method::GET, &mut svc).await; let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert!(body.is_empty()); assert!(body.is_empty());
} }
@ -974,7 +1051,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn get_handler() { async fn get_handler() {
let mut svc = MethodRouter::new().get(ok); let mut svc = MethodRouter::new().get(ok);
let (status, body) = call(Method::GET, &mut svc).await; let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK); assert_eq!(status, StatusCode::OK);
assert_eq!(body, "ok"); assert_eq!(body, "ok");
} }
@ -982,7 +1059,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn get_accepts_head() { async fn get_accepts_head() {
let mut svc = MethodRouter::new().get(ok); let mut svc = MethodRouter::new().get(ok);
let (status, body) = call(Method::HEAD, &mut svc).await; let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::OK); assert_eq!(status, StatusCode::OK);
assert!(body.is_empty()); assert!(body.is_empty());
} }
@ -990,7 +1067,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn head_takes_precedence_over_get() { async fn head_takes_precedence_over_get() {
let mut svc = MethodRouter::new().head(created).get(ok); let mut svc = MethodRouter::new().head(created).get(ok);
let (status, body) = call(Method::HEAD, &mut svc).await; let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::CREATED); assert_eq!(status, StatusCode::CREATED);
assert!(body.is_empty()); assert!(body.is_empty());
} }
@ -999,10 +1076,10 @@ mod tests {
async fn merge() { async fn merge() {
let mut svc = get(ok).merge(post(ok)); let mut svc = get(ok).merge(post(ok));
let (status, _) = call(Method::GET, &mut svc).await; let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK); assert_eq!(status, StatusCode::OK);
let (status, _) = call(Method::POST, &mut svc).await; let (status, _, _) = call(Method::POST, &mut svc).await;
assert_eq!(status, StatusCode::OK); assert_eq!(status, StatusCode::OK);
} }
@ -1013,11 +1090,11 @@ mod tests {
.layer(RequireAuthorizationLayer::bearer("password")); .layer(RequireAuthorizationLayer::bearer("password"));
// method with route // method with route
let (status, _) = call(Method::GET, &mut svc).await; let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED); assert_eq!(status, StatusCode::UNAUTHORIZED);
// method without route // method without route
let (status, _) = call(Method::DELETE, &mut svc).await; let (status, _, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED); assert_eq!(status, StatusCode::UNAUTHORIZED);
} }
@ -1028,11 +1105,11 @@ mod tests {
.route_layer(RequireAuthorizationLayer::bearer("password")); .route_layer(RequireAuthorizationLayer::bearer("password"));
// method with route // method with route
let (status, _) = call(Method::GET, &mut svc).await; let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED); assert_eq!(status, StatusCode::UNAUTHORIZED);
// method without route // method without route
let (status, _) = call(Method::DELETE, &mut svc).await; let (status, _, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
} }
@ -1062,7 +1139,93 @@ mod tests {
crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service()); crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
} }
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, String) #[tokio::test]
async fn sets_allow_header() {
let mut svc = MethodRouter::new().put(ok).patch(ok);
let (status, headers, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "PUT,PATCH");
}
#[tokio::test]
async fn sets_allow_header_get_head() {
let mut svc = MethodRouter::new().get(ok).head(ok);
let (status, headers, _) = call(Method::PUT, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[tokio::test]
async fn empty_allow_header_by_default() {
let mut svc = MethodRouter::new();
let (status, headers, _) = call(Method::PATCH, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "");
}
#[tokio::test]
async fn allow_header_when_merging() {
let a = put(ok).patch(ok);
let b = get(ok).head(ok);
let mut svc = a.merge(b);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
}
#[tokio::test]
async fn allow_header_any() {
let mut svc = any(ok);
let (status, headers, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert!(!headers.contains_key(ALLOW));
}
#[tokio::test]
async fn allow_header_with_fallback() {
let mut svc = MethodRouter::new().get(ok).fallback(
(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(),
);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[tokio::test]
async fn allow_header_with_fallback_that_sets_allow() {
async fn fallback(method: Method) -> Response {
if method == Method::POST {
"OK".into_response()
} else {
let headers = crate::response::Headers([(ALLOW, "GET,POST")]);
(
StatusCode::METHOD_NOT_ALLOWED,
headers,
"Method not allowed",
)
.into_response()
}
}
let mut svc = MethodRouter::new()
.get(ok)
.fallback(fallback.into_service());
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, _, _) = call(Method::POST, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,POST");
}
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
where where
S: Service<Request<Body>, Response = Response, Error = Infallible>, S: Service<Request<Body>, Response = Response, Error = Infallible>,
{ {
@ -1074,7 +1237,7 @@ mod tests {
let response = svc.ready().await.unwrap().call(request).await.unwrap(); let response = svc.ready().await.unwrap().call(request).await.unwrap();
let (parts, body) = response.into_parts(); let (parts, body) = response.into_parts();
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap(); let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
(parts.status, body) (parts.status, parts.headers, body)
} }
async fn ok() -> (StatusCode, &'static str) { async fn ok() -> (StatusCode, &'static str) {

View file

@ -2,7 +2,8 @@ use crate::{
body::{boxed, Body, Empty}, body::{boxed, Body, Empty},
response::Response, response::Response,
}; };
use http::Request; use bytes::Bytes;
use http::{header, HeaderValue, Request};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::{ use std::{
convert::Infallible, convert::Infallible,
@ -70,6 +71,7 @@ pin_project! {
Request<B>, Request<B>,
>, >,
strip_body: bool, strip_body: bool,
allow_header: Option<Bytes>,
} }
} }
@ -80,6 +82,7 @@ impl<B, E> RouteFuture<B, E> {
RouteFuture { RouteFuture {
future, future,
strip_body: false, strip_body: false,
allow_header: None,
} }
} }
@ -87,22 +90,41 @@ impl<B, E> RouteFuture<B, E> {
self.strip_body = strip_body; self.strip_body = strip_body;
self self
} }
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
self.allow_header = Some(allow_header);
self
}
} }
impl<B, E> Future for RouteFuture<B, E> { impl<B, E> Future for RouteFuture<B, E> {
type Output = Result<Response, E>; type Output = Result<Response, E>;
#[inline] #[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let strip_body = self.strip_body; let strip_body = self.strip_body;
let allow_header = self.allow_header.take();
match self.project().future.poll(cx) { match self.project().future.poll(cx) {
Poll::Ready(Ok(res)) => { Poll::Ready(Ok(res)) => {
if strip_body { let mut res = if strip_body {
Poll::Ready(Ok(res.map(|_| boxed(Empty::new())))) res.map(|_| boxed(Empty::new()))
} else { } else {
Poll::Ready(Ok(res)) res
};
match allow_header {
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => {
res.headers_mut().insert(
header::ALLOW,
HeaderValue::from_maybe_shared(allow_header)
.expect("invalid `Allow` header"),
);
}
_ => {}
} }
Poll::Ready(Ok(res))
} }
Poll::Ready(Err(err)) => Poll::Ready(Err(err)), Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,