Set Allow header when returning 405 Method Not Allowed (#733)

This commit is contained in:
David Pedersen 2022-01-30 21:41:34 +01:00 committed by GitHub
parent a04fc42d75
commit 513766b9e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 221 additions and 26 deletions

View file

@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `PathRejection`
- `MatchedPathRejection`
- `WebSocketUpgradeRejection`
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed`
[#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665

View file

@ -51,3 +51,12 @@ async fn fallback_two() -> impl IntoResponse {}
# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
## Setting the `Allow` header
By default `MethodRouter` will set the `Allow` header when returning `405 Method
Not Allowed`. This is also done when the fallback is used unless the response
generated by the fallback already sets the `Allow` header.
This means if you use `fallback` to accept additional methods make sure you set
the `Allow` header correctly.

View file

@ -7,6 +7,7 @@ use crate::{
routing::{Fallback, MethodFilter, Route},
BoxError,
};
use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
@ -386,7 +387,7 @@ where
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
MethodRouter::new().fallback(svc)
MethodRouter::new().fallback(svc).skip_allow_header()
}
top_level_handler_fn!(delete, DELETE);
@ -469,7 +470,9 @@ where
B: Send + 'static,
T: 'static,
{
MethodRouter::new().fallback_boxed_response_body(handler.into_service())
MethodRouter::new()
.fallback_boxed_response_body(handler.into_service())
.skip_allow_header()
}
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
@ -484,9 +487,20 @@ pub struct MethodRouter<B = Body, E = Infallible> {
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Fallback<B, E>,
allow_header: AllowHeader,
_request_body: PhantomData<fn() -> (B, E)>,
}
#[derive(Clone)]
enum AllowHeader {
/// No `Allow` header value has been built-up yet. This is the default state
None,
/// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
Skip,
/// The current value of the `Allow` header.
Bytes(BytesMut),
}
impl<B, E> fmt::Debug for MethodRouter<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MethodRouter")
@ -522,6 +536,7 @@ impl<B, E> MethodRouter<B, E> {
post: None,
put: None,
trace: None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
_request_body: PhantomData,
}
@ -677,6 +692,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn),
fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
_request_body: PhantomData,
}
}
@ -710,6 +726,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn),
fallback: self.fallback,
allow_header: self.allow_header,
_request_body: PhantomData,
}
}
@ -741,6 +758,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put,
trace,
fallback,
allow_header,
_request_body: _,
} = self;
@ -754,6 +772,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put: put_other,
trace: trace_other,
fallback: fallback_other,
allow_header: allow_header_other,
_request_body: _,
} = other;
@ -775,6 +794,18 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
}
};
let allow_header = match (allow_header, allow_header_other) {
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
a.extend_from_slice(b",");
a.extend_from_slice(&b);
AllowHeader::Bytes(a)
}
};
Self {
get,
head,
@ -785,6 +816,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put,
trace,
fallback,
allow_header,
_request_body: PhantomData,
}
}
@ -821,31 +853,41 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
mut put,
mut trace,
fallback,
mut allow_header,
_request_body: _,
} = self;
let svc = Some(Route::new(svc));
if filter.contains(MethodFilter::GET) {
get = svc.clone();
append_allow_header(&mut allow_header, "GET");
append_allow_header(&mut allow_header, "HEAD");
}
if filter.contains(MethodFilter::HEAD) {
append_allow_header(&mut allow_header, "HEAD");
head = svc.clone();
}
if filter.contains(MethodFilter::DELETE) {
append_allow_header(&mut allow_header, "DELETE");
delete = svc.clone();
}
if filter.contains(MethodFilter::OPTIONS) {
append_allow_header(&mut allow_header, "OPTIONS");
options = svc.clone();
}
if filter.contains(MethodFilter::PATCH) {
append_allow_header(&mut allow_header, "PATCH");
patch = svc.clone();
}
if filter.contains(MethodFilter::POST) {
append_allow_header(&mut allow_header, "POST");
post = svc.clone();
}
if filter.contains(MethodFilter::PUT) {
append_allow_header(&mut allow_header, "PUT");
put = svc.clone();
}
if filter.contains(MethodFilter::TRACE) {
append_allow_header(&mut allow_header, "TRACE");
trace = svc;
}
Self {
@ -858,9 +900,35 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
put,
trace,
fallback,
allow_header,
_request_body: PhantomData,
}
}
fn skip_allow_header(mut self) -> Self {
self.allow_header = AllowHeader::Skip;
self
}
}
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
match allow_header {
AllowHeader::None => {
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
}
AllowHeader::Skip => {}
AllowHeader::Bytes(allow_header) => {
if let Ok(s) = std::str::from_utf8(allow_header) {
if !s.contains(method) {
allow_header.extend_from_slice(b",");
allow_header.extend_from_slice(method.as_bytes());
}
} else {
#[cfg(debug_assertions)]
panic!("`allow_header` contained invalid uft-8. This should never happen")
}
}
}
}
impl<B, E> Clone for MethodRouter<B, E> {
@ -875,6 +943,7 @@ impl<B, E> Clone for MethodRouter<B, E> {
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
_request_body: PhantomData,
}
}
@ -931,6 +1000,7 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
put,
trace,
fallback,
allow_header,
_request_body: _,
} = self;
@ -944,13 +1014,18 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
match fallback {
Fallback::Default(fallback) => {
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
}
Fallback::Custom(fallback) => {
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
}
let future =
match fallback {
Fallback::Default(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
.strip_body(method == Method::HEAD),
Fallback::Custom(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
.strip_body(method == Method::HEAD),
};
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
AllowHeader::Skip => future,
AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
}
}
}
@ -959,6 +1034,8 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
mod tests {
use super::*;
use crate::{body::Body, error_handling::HandleErrorLayer};
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceExt};
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
@ -966,7 +1043,7 @@ mod tests {
#[tokio::test]
async fn method_not_allowed_by_default() {
let mut svc = MethodRouter::new();
let (status, body) = call(Method::GET, &mut svc).await;
let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert!(body.is_empty());
}
@ -974,7 +1051,7 @@ mod tests {
#[tokio::test]
async fn get_handler() {
let mut svc = MethodRouter::new().get(ok);
let (status, body) = call(Method::GET, &mut svc).await;
let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert_eq!(body, "ok");
}
@ -982,7 +1059,7 @@ mod tests {
#[tokio::test]
async fn get_accepts_head() {
let mut svc = MethodRouter::new().get(ok);
let (status, body) = call(Method::HEAD, &mut svc).await;
let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert!(body.is_empty());
}
@ -990,7 +1067,7 @@ mod tests {
#[tokio::test]
async fn head_takes_precedence_over_get() {
let mut svc = MethodRouter::new().head(created).get(ok);
let (status, body) = call(Method::HEAD, &mut svc).await;
let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::CREATED);
assert!(body.is_empty());
}
@ -999,10 +1076,10 @@ mod tests {
async fn merge() {
let mut svc = get(ok).merge(post(ok));
let (status, _) = call(Method::GET, &mut svc).await;
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, _) = call(Method::POST, &mut svc).await;
let (status, _, _) = call(Method::POST, &mut svc).await;
assert_eq!(status, StatusCode::OK);
}
@ -1013,11 +1090,11 @@ mod tests {
.layer(RequireAuthorizationLayer::bearer("password"));
// method with route
let (status, _) = call(Method::GET, &mut svc).await;
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED);
// method without route
let (status, _) = call(Method::DELETE, &mut svc).await;
let (status, _, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED);
}
@ -1028,11 +1105,11 @@ mod tests {
.route_layer(RequireAuthorizationLayer::bearer("password"));
// method with route
let (status, _) = call(Method::GET, &mut svc).await;
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED);
// method without route
let (status, _) = call(Method::DELETE, &mut svc).await;
let (status, _, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
}
@ -1062,7 +1139,93 @@ mod tests {
crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
}
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, String)
#[tokio::test]
async fn sets_allow_header() {
let mut svc = MethodRouter::new().put(ok).patch(ok);
let (status, headers, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "PUT,PATCH");
}
#[tokio::test]
async fn sets_allow_header_get_head() {
let mut svc = MethodRouter::new().get(ok).head(ok);
let (status, headers, _) = call(Method::PUT, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[tokio::test]
async fn empty_allow_header_by_default() {
let mut svc = MethodRouter::new();
let (status, headers, _) = call(Method::PATCH, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "");
}
#[tokio::test]
async fn allow_header_when_merging() {
let a = put(ok).patch(ok);
let b = get(ok).head(ok);
let mut svc = a.merge(b);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
}
#[tokio::test]
async fn allow_header_any() {
let mut svc = any(ok);
let (status, headers, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert!(!headers.contains_key(ALLOW));
}
#[tokio::test]
async fn allow_header_with_fallback() {
let mut svc = MethodRouter::new().get(ok).fallback(
(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(),
);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[tokio::test]
async fn allow_header_with_fallback_that_sets_allow() {
async fn fallback(method: Method) -> Response {
if method == Method::POST {
"OK".into_response()
} else {
let headers = crate::response::Headers([(ALLOW, "GET,POST")]);
(
StatusCode::METHOD_NOT_ALLOWED,
headers,
"Method not allowed",
)
.into_response()
}
}
let mut svc = MethodRouter::new()
.get(ok)
.fallback(fallback.into_service());
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, _, _) = call(Method::POST, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(headers[ALLOW], "GET,POST");
}
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
where
S: Service<Request<Body>, Response = Response, Error = Infallible>,
{
@ -1074,7 +1237,7 @@ mod tests {
let response = svc.ready().await.unwrap().call(request).await.unwrap();
let (parts, body) = response.into_parts();
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
(parts.status, body)
(parts.status, parts.headers, body)
}
async fn ok() -> (StatusCode, &'static str) {

View file

@ -2,7 +2,8 @@ use crate::{
body::{boxed, Body, Empty},
response::Response,
};
use http::Request;
use bytes::Bytes;
use http::{header, HeaderValue, Request};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
@ -70,6 +71,7 @@ pin_project! {
Request<B>,
>,
strip_body: bool,
allow_header: Option<Bytes>,
}
}
@ -80,6 +82,7 @@ impl<B, E> RouteFuture<B, E> {
RouteFuture {
future,
strip_body: false,
allow_header: None,
}
}
@ -87,22 +90,41 @@ impl<B, E> RouteFuture<B, E> {
self.strip_body = strip_body;
self
}
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
self.allow_header = Some(allow_header);
self
}
}
impl<B, E> Future for RouteFuture<B, E> {
type Output = Result<Response, E>;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let strip_body = self.strip_body;
let allow_header = self.allow_header.take();
match self.project().future.poll(cx) {
Poll::Ready(Ok(res)) => {
if strip_body {
Poll::Ready(Ok(res.map(|_| boxed(Empty::new()))))
let mut res = if strip_body {
res.map(|_| boxed(Empty::new()))
} else {
Poll::Ready(Ok(res))
res
};
match allow_header {
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => {
res.headers_mut().insert(
header::ALLOW,
HeaderValue::from_maybe_shared(allow_header)
.expect("invalid `Allow` header"),
);
}
_ => {}
}
Poll::Ready(Ok(res))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => Poll::Pending,