mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-31 16:40:42 +01:00
Set Allow
header when returning 405 Method Not Allowed
(#733)
This commit is contained in:
parent
a04fc42d75
commit
513766b9e7
4 changed files with 221 additions and 26 deletions
|
@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- `PathRejection`
|
||||
- `MatchedPathRejection`
|
||||
- `WebSocketUpgradeRejection`
|
||||
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed`
|
||||
|
||||
[#644]: https://github.com/tokio-rs/axum/pull/644
|
||||
[#665]: https://github.com/tokio-rs/axum/pull/665
|
||||
|
|
|
@ -51,3 +51,12 @@ async fn fallback_two() -> impl IntoResponse {}
|
|||
# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
# };
|
||||
```
|
||||
|
||||
## Setting the `Allow` header
|
||||
|
||||
By default `MethodRouter` will set the `Allow` header when returning `405 Method
|
||||
Not Allowed`. This is also done when the fallback is used unless the response
|
||||
generated by the fallback already sets the `Allow` header.
|
||||
|
||||
This means if you use `fallback` to accept additional methods make sure you set
|
||||
the `Allow` header correctly.
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::{
|
|||
routing::{Fallback, MethodFilter, Route},
|
||||
BoxError,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
|
@ -386,7 +387,7 @@ where
|
|||
ResBody: HttpBody<Data = Bytes> + Send + 'static,
|
||||
ResBody::Error: Into<BoxError>,
|
||||
{
|
||||
MethodRouter::new().fallback(svc)
|
||||
MethodRouter::new().fallback(svc).skip_allow_header()
|
||||
}
|
||||
|
||||
top_level_handler_fn!(delete, DELETE);
|
||||
|
@ -469,7 +470,9 @@ where
|
|||
B: Send + 'static,
|
||||
T: 'static,
|
||||
{
|
||||
MethodRouter::new().fallback_boxed_response_body(handler.into_service())
|
||||
MethodRouter::new()
|
||||
.fallback_boxed_response_body(handler.into_service())
|
||||
.skip_allow_header()
|
||||
}
|
||||
|
||||
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
|
||||
|
@ -484,9 +487,20 @@ pub struct MethodRouter<B = Body, E = Infallible> {
|
|||
put: Option<Route<B, E>>,
|
||||
trace: Option<Route<B, E>>,
|
||||
fallback: Fallback<B, E>,
|
||||
allow_header: AllowHeader,
|
||||
_request_body: PhantomData<fn() -> (B, E)>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum AllowHeader {
|
||||
/// No `Allow` header value has been built-up yet. This is the default state
|
||||
None,
|
||||
/// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
|
||||
Skip,
|
||||
/// The current value of the `Allow` header.
|
||||
Bytes(BytesMut),
|
||||
}
|
||||
|
||||
impl<B, E> fmt::Debug for MethodRouter<B, E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MethodRouter")
|
||||
|
@ -522,6 +536,7 @@ impl<B, E> MethodRouter<B, E> {
|
|||
post: None,
|
||||
put: None,
|
||||
trace: None,
|
||||
allow_header: AllowHeader::None,
|
||||
fallback: Fallback::Default(fallback),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
|
@ -677,6 +692,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put: self.put.map(layer_fn),
|
||||
trace: self.trace.map(layer_fn),
|
||||
fallback: self.fallback.map(layer_fn),
|
||||
allow_header: self.allow_header,
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -710,6 +726,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put: self.put.map(layer_fn),
|
||||
trace: self.trace.map(layer_fn),
|
||||
fallback: self.fallback,
|
||||
allow_header: self.allow_header,
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -741,6 +758,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put,
|
||||
trace,
|
||||
fallback,
|
||||
allow_header,
|
||||
_request_body: _,
|
||||
} = self;
|
||||
|
||||
|
@ -754,6 +772,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put: put_other,
|
||||
trace: trace_other,
|
||||
fallback: fallback_other,
|
||||
allow_header: allow_header_other,
|
||||
_request_body: _,
|
||||
} = other;
|
||||
|
||||
|
@ -775,6 +794,18 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
}
|
||||
};
|
||||
|
||||
let allow_header = match (allow_header, allow_header_other) {
|
||||
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
|
||||
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
|
||||
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
|
||||
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
|
||||
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
|
||||
a.extend_from_slice(b",");
|
||||
a.extend_from_slice(&b);
|
||||
AllowHeader::Bytes(a)
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
get,
|
||||
head,
|
||||
|
@ -785,6 +816,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put,
|
||||
trace,
|
||||
fallback,
|
||||
allow_header,
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -821,31 +853,41 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
mut put,
|
||||
mut trace,
|
||||
fallback,
|
||||
mut allow_header,
|
||||
_request_body: _,
|
||||
} = self;
|
||||
let svc = Some(Route::new(svc));
|
||||
if filter.contains(MethodFilter::GET) {
|
||||
get = svc.clone();
|
||||
append_allow_header(&mut allow_header, "GET");
|
||||
append_allow_header(&mut allow_header, "HEAD");
|
||||
}
|
||||
if filter.contains(MethodFilter::HEAD) {
|
||||
append_allow_header(&mut allow_header, "HEAD");
|
||||
head = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::DELETE) {
|
||||
append_allow_header(&mut allow_header, "DELETE");
|
||||
delete = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::OPTIONS) {
|
||||
append_allow_header(&mut allow_header, "OPTIONS");
|
||||
options = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::PATCH) {
|
||||
append_allow_header(&mut allow_header, "PATCH");
|
||||
patch = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::POST) {
|
||||
append_allow_header(&mut allow_header, "POST");
|
||||
post = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::PUT) {
|
||||
append_allow_header(&mut allow_header, "PUT");
|
||||
put = svc.clone();
|
||||
}
|
||||
if filter.contains(MethodFilter::TRACE) {
|
||||
append_allow_header(&mut allow_header, "TRACE");
|
||||
trace = svc;
|
||||
}
|
||||
Self {
|
||||
|
@ -858,9 +900,35 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
|||
put,
|
||||
trace,
|
||||
fallback,
|
||||
allow_header,
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn skip_allow_header(mut self) -> Self {
|
||||
self.allow_header = AllowHeader::Skip;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
|
||||
match allow_header {
|
||||
AllowHeader::None => {
|
||||
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
|
||||
}
|
||||
AllowHeader::Skip => {}
|
||||
AllowHeader::Bytes(allow_header) => {
|
||||
if let Ok(s) = std::str::from_utf8(allow_header) {
|
||||
if !s.contains(method) {
|
||||
allow_header.extend_from_slice(b",");
|
||||
allow_header.extend_from_slice(method.as_bytes());
|
||||
}
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
panic!("`allow_header` contained invalid uft-8. This should never happen")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Clone for MethodRouter<B, E> {
|
||||
|
@ -875,6 +943,7 @@ impl<B, E> Clone for MethodRouter<B, E> {
|
|||
put: self.put.clone(),
|
||||
trace: self.trace.clone(),
|
||||
fallback: self.fallback.clone(),
|
||||
allow_header: self.allow_header.clone(),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -931,6 +1000,7 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
|||
put,
|
||||
trace,
|
||||
fallback,
|
||||
allow_header,
|
||||
_request_body: _,
|
||||
} = self;
|
||||
|
||||
|
@ -944,13 +1014,18 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
|||
call!(req, method, DELETE, delete);
|
||||
call!(req, method, TRACE, trace);
|
||||
|
||||
match fallback {
|
||||
Fallback::Default(fallback) => {
|
||||
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
|
||||
}
|
||||
Fallback::Custom(fallback) => {
|
||||
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
|
||||
}
|
||||
let future =
|
||||
match fallback {
|
||||
Fallback::Default(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
|
||||
.strip_body(method == Method::HEAD),
|
||||
Fallback::Custom(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
|
||||
.strip_body(method == Method::HEAD),
|
||||
};
|
||||
|
||||
match allow_header {
|
||||
AllowHeader::None => future.allow_header(Bytes::new()),
|
||||
AllowHeader::Skip => future,
|
||||
AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -959,6 +1034,8 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{body::Body, error_handling::HandleErrorLayer};
|
||||
use axum_core::response::IntoResponse;
|
||||
use http::{header::ALLOW, HeaderMap};
|
||||
use std::time::Duration;
|
||||
use tower::{timeout::TimeoutLayer, Service, ServiceExt};
|
||||
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
|
||||
|
@ -966,7 +1043,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn method_not_allowed_by_default() {
|
||||
let mut svc = MethodRouter::new();
|
||||
let (status, body) = call(Method::GET, &mut svc).await;
|
||||
let (status, _, body) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert!(body.is_empty());
|
||||
}
|
||||
|
@ -974,7 +1051,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn get_handler() {
|
||||
let mut svc = MethodRouter::new().get(ok);
|
||||
let (status, body) = call(Method::GET, &mut svc).await;
|
||||
let (status, _, body) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
assert_eq!(body, "ok");
|
||||
}
|
||||
|
@ -982,7 +1059,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn get_accepts_head() {
|
||||
let mut svc = MethodRouter::new().get(ok);
|
||||
let (status, body) = call(Method::HEAD, &mut svc).await;
|
||||
let (status, _, body) = call(Method::HEAD, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
assert!(body.is_empty());
|
||||
}
|
||||
|
@ -990,7 +1067,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn head_takes_precedence_over_get() {
|
||||
let mut svc = MethodRouter::new().head(created).get(ok);
|
||||
let (status, body) = call(Method::HEAD, &mut svc).await;
|
||||
let (status, _, body) = call(Method::HEAD, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::CREATED);
|
||||
assert!(body.is_empty());
|
||||
}
|
||||
|
@ -999,10 +1076,10 @@ mod tests {
|
|||
async fn merge() {
|
||||
let mut svc = get(ok).merge(post(ok));
|
||||
|
||||
let (status, _) = call(Method::GET, &mut svc).await;
|
||||
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
|
||||
let (status, _) = call(Method::POST, &mut svc).await;
|
||||
let (status, _, _) = call(Method::POST, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
}
|
||||
|
||||
|
@ -1013,11 +1090,11 @@ mod tests {
|
|||
.layer(RequireAuthorizationLayer::bearer("password"));
|
||||
|
||||
// method with route
|
||||
let (status, _) = call(Method::GET, &mut svc).await;
|
||||
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||
|
||||
// method without route
|
||||
let (status, _) = call(Method::DELETE, &mut svc).await;
|
||||
let (status, _, _) = call(Method::DELETE, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
|
@ -1028,11 +1105,11 @@ mod tests {
|
|||
.route_layer(RequireAuthorizationLayer::bearer("password"));
|
||||
|
||||
// method with route
|
||||
let (status, _) = call(Method::GET, &mut svc).await;
|
||||
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||
|
||||
// method without route
|
||||
let (status, _) = call(Method::DELETE, &mut svc).await;
|
||||
let (status, _, _) = call(Method::DELETE, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
}
|
||||
|
||||
|
@ -1062,7 +1139,93 @@ mod tests {
|
|||
crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
|
||||
}
|
||||
|
||||
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, String)
|
||||
#[tokio::test]
|
||||
async fn sets_allow_header() {
|
||||
let mut svc = MethodRouter::new().put(ok).patch(ok);
|
||||
let (status, headers, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "PUT,PATCH");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sets_allow_header_get_head() {
|
||||
let mut svc = MethodRouter::new().get(ok).head(ok);
|
||||
let (status, headers, _) = call(Method::PUT, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "GET,HEAD");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_allow_header_by_default() {
|
||||
let mut svc = MethodRouter::new();
|
||||
let (status, headers, _) = call(Method::PATCH, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn allow_header_when_merging() {
|
||||
let a = put(ok).patch(ok);
|
||||
let b = get(ok).head(ok);
|
||||
let mut svc = a.merge(b);
|
||||
|
||||
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn allow_header_any() {
|
||||
let mut svc = any(ok);
|
||||
|
||||
let (status, headers, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
assert!(!headers.contains_key(ALLOW));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn allow_header_with_fallback() {
|
||||
let mut svc = MethodRouter::new().get(ok).fallback(
|
||||
(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(),
|
||||
);
|
||||
|
||||
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "GET,HEAD");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn allow_header_with_fallback_that_sets_allow() {
|
||||
async fn fallback(method: Method) -> Response {
|
||||
if method == Method::POST {
|
||||
"OK".into_response()
|
||||
} else {
|
||||
let headers = crate::response::Headers([(ALLOW, "GET,POST")]);
|
||||
(
|
||||
StatusCode::METHOD_NOT_ALLOWED,
|
||||
headers,
|
||||
"Method not allowed",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
let mut svc = MethodRouter::new()
|
||||
.get(ok)
|
||||
.fallback(fallback.into_service());
|
||||
|
||||
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
|
||||
let (status, _, _) = call(Method::POST, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
|
||||
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||
assert_eq!(headers[ALLOW], "GET,POST");
|
||||
}
|
||||
|
||||
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
|
||||
where
|
||||
S: Service<Request<Body>, Response = Response, Error = Infallible>,
|
||||
{
|
||||
|
@ -1074,7 +1237,7 @@ mod tests {
|
|||
let response = svc.ready().await.unwrap().call(request).await.unwrap();
|
||||
let (parts, body) = response.into_parts();
|
||||
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
|
||||
(parts.status, body)
|
||||
(parts.status, parts.headers, body)
|
||||
}
|
||||
|
||||
async fn ok() -> (StatusCode, &'static str) {
|
||||
|
|
|
@ -2,7 +2,8 @@ use crate::{
|
|||
body::{boxed, Body, Empty},
|
||||
response::Response,
|
||||
};
|
||||
use http::Request;
|
||||
use bytes::Bytes;
|
||||
use http::{header, HeaderValue, Request};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
|
@ -70,6 +71,7 @@ pin_project! {
|
|||
Request<B>,
|
||||
>,
|
||||
strip_body: bool,
|
||||
allow_header: Option<Bytes>,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,6 +82,7 @@ impl<B, E> RouteFuture<B, E> {
|
|||
RouteFuture {
|
||||
future,
|
||||
strip_body: false,
|
||||
allow_header: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,22 +90,41 @@ impl<B, E> RouteFuture<B, E> {
|
|||
self.strip_body = strip_body;
|
||||
self
|
||||
}
|
||||
|
||||
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
|
||||
self.allow_header = Some(allow_header);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Future for RouteFuture<B, E> {
|
||||
type Output = Result<Response, E>;
|
||||
|
||||
#[inline]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let strip_body = self.strip_body;
|
||||
let allow_header = self.allow_header.take();
|
||||
|
||||
match self.project().future.poll(cx) {
|
||||
Poll::Ready(Ok(res)) => {
|
||||
if strip_body {
|
||||
Poll::Ready(Ok(res.map(|_| boxed(Empty::new()))))
|
||||
let mut res = if strip_body {
|
||||
res.map(|_| boxed(Empty::new()))
|
||||
} else {
|
||||
Poll::Ready(Ok(res))
|
||||
res
|
||||
};
|
||||
|
||||
match allow_header {
|
||||
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => {
|
||||
res.headers_mut().insert(
|
||||
header::ALLOW,
|
||||
HeaderValue::from_maybe_shared(allow_header)
|
||||
.expect("invalid `Allow` header"),
|
||||
);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
|
Loading…
Reference in a new issue