mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-05 18:27:07 +01:00
Set Allow
header when returning 405 Method Not Allowed
(#733)
This commit is contained in:
parent
a04fc42d75
commit
513766b9e7
4 changed files with 221 additions and 26 deletions
|
@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- `PathRejection`
|
- `PathRejection`
|
||||||
- `MatchedPathRejection`
|
- `MatchedPathRejection`
|
||||||
- `WebSocketUpgradeRejection`
|
- `WebSocketUpgradeRejection`
|
||||||
|
- **fixed:** Set `Allow` header when responding with `405 Method Not Allowed`
|
||||||
|
|
||||||
[#644]: https://github.com/tokio-rs/axum/pull/644
|
[#644]: https://github.com/tokio-rs/axum/pull/644
|
||||||
[#665]: https://github.com/tokio-rs/axum/pull/665
|
[#665]: https://github.com/tokio-rs/axum/pull/665
|
||||||
|
|
|
@ -51,3 +51,12 @@ async fn fallback_two() -> impl IntoResponse {}
|
||||||
# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
# hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
# };
|
# };
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Setting the `Allow` header
|
||||||
|
|
||||||
|
By default `MethodRouter` will set the `Allow` header when returning `405 Method
|
||||||
|
Not Allowed`. This is also done when the fallback is used unless the response
|
||||||
|
generated by the fallback already sets the `Allow` header.
|
||||||
|
|
||||||
|
This means if you use `fallback` to accept additional methods make sure you set
|
||||||
|
the `Allow` header correctly.
|
||||||
|
|
|
@ -7,6 +7,7 @@ use crate::{
|
||||||
routing::{Fallback, MethodFilter, Route},
|
routing::{Fallback, MethodFilter, Route},
|
||||||
BoxError,
|
BoxError,
|
||||||
};
|
};
|
||||||
|
use bytes::BytesMut;
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
fmt,
|
fmt,
|
||||||
|
@ -386,7 +387,7 @@ where
|
||||||
ResBody: HttpBody<Data = Bytes> + Send + 'static,
|
ResBody: HttpBody<Data = Bytes> + Send + 'static,
|
||||||
ResBody::Error: Into<BoxError>,
|
ResBody::Error: Into<BoxError>,
|
||||||
{
|
{
|
||||||
MethodRouter::new().fallback(svc)
|
MethodRouter::new().fallback(svc).skip_allow_header()
|
||||||
}
|
}
|
||||||
|
|
||||||
top_level_handler_fn!(delete, DELETE);
|
top_level_handler_fn!(delete, DELETE);
|
||||||
|
@ -469,7 +470,9 @@ where
|
||||||
B: Send + 'static,
|
B: Send + 'static,
|
||||||
T: 'static,
|
T: 'static,
|
||||||
{
|
{
|
||||||
MethodRouter::new().fallback_boxed_response_body(handler.into_service())
|
MethodRouter::new()
|
||||||
|
.fallback_boxed_response_body(handler.into_service())
|
||||||
|
.skip_allow_header()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
|
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
|
||||||
|
@ -484,9 +487,20 @@ pub struct MethodRouter<B = Body, E = Infallible> {
|
||||||
put: Option<Route<B, E>>,
|
put: Option<Route<B, E>>,
|
||||||
trace: Option<Route<B, E>>,
|
trace: Option<Route<B, E>>,
|
||||||
fallback: Fallback<B, E>,
|
fallback: Fallback<B, E>,
|
||||||
|
allow_header: AllowHeader,
|
||||||
_request_body: PhantomData<fn() -> (B, E)>,
|
_request_body: PhantomData<fn() -> (B, E)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum AllowHeader {
|
||||||
|
/// No `Allow` header value has been built-up yet. This is the default state
|
||||||
|
None,
|
||||||
|
/// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
|
||||||
|
Skip,
|
||||||
|
/// The current value of the `Allow` header.
|
||||||
|
Bytes(BytesMut),
|
||||||
|
}
|
||||||
|
|
||||||
impl<B, E> fmt::Debug for MethodRouter<B, E> {
|
impl<B, E> fmt::Debug for MethodRouter<B, E> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("MethodRouter")
|
f.debug_struct("MethodRouter")
|
||||||
|
@ -522,6 +536,7 @@ impl<B, E> MethodRouter<B, E> {
|
||||||
post: None,
|
post: None,
|
||||||
put: None,
|
put: None,
|
||||||
trace: None,
|
trace: None,
|
||||||
|
allow_header: AllowHeader::None,
|
||||||
fallback: Fallback::Default(fallback),
|
fallback: Fallback::Default(fallback),
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
|
@ -677,6 +692,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put: self.put.map(layer_fn),
|
put: self.put.map(layer_fn),
|
||||||
trace: self.trace.map(layer_fn),
|
trace: self.trace.map(layer_fn),
|
||||||
fallback: self.fallback.map(layer_fn),
|
fallback: self.fallback.map(layer_fn),
|
||||||
|
allow_header: self.allow_header,
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -710,6 +726,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put: self.put.map(layer_fn),
|
put: self.put.map(layer_fn),
|
||||||
trace: self.trace.map(layer_fn),
|
trace: self.trace.map(layer_fn),
|
||||||
fallback: self.fallback,
|
fallback: self.fallback,
|
||||||
|
allow_header: self.allow_header,
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -741,6 +758,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put,
|
put,
|
||||||
trace,
|
trace,
|
||||||
fallback,
|
fallback,
|
||||||
|
allow_header,
|
||||||
_request_body: _,
|
_request_body: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
|
@ -754,6 +772,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put: put_other,
|
put: put_other,
|
||||||
trace: trace_other,
|
trace: trace_other,
|
||||||
fallback: fallback_other,
|
fallback: fallback_other,
|
||||||
|
allow_header: allow_header_other,
|
||||||
_request_body: _,
|
_request_body: _,
|
||||||
} = other;
|
} = other;
|
||||||
|
|
||||||
|
@ -775,6 +794,18 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let allow_header = match (allow_header, allow_header_other) {
|
||||||
|
(AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
|
||||||
|
(AllowHeader::None, AllowHeader::None) => AllowHeader::None,
|
||||||
|
(AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
|
||||||
|
(AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
|
||||||
|
(AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
|
||||||
|
a.extend_from_slice(b",");
|
||||||
|
a.extend_from_slice(&b);
|
||||||
|
AllowHeader::Bytes(a)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
get,
|
get,
|
||||||
head,
|
head,
|
||||||
|
@ -785,6 +816,7 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put,
|
put,
|
||||||
trace,
|
trace,
|
||||||
fallback,
|
fallback,
|
||||||
|
allow_header,
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -821,31 +853,41 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
mut put,
|
mut put,
|
||||||
mut trace,
|
mut trace,
|
||||||
fallback,
|
fallback,
|
||||||
|
mut allow_header,
|
||||||
_request_body: _,
|
_request_body: _,
|
||||||
} = self;
|
} = self;
|
||||||
let svc = Some(Route::new(svc));
|
let svc = Some(Route::new(svc));
|
||||||
if filter.contains(MethodFilter::GET) {
|
if filter.contains(MethodFilter::GET) {
|
||||||
get = svc.clone();
|
get = svc.clone();
|
||||||
|
append_allow_header(&mut allow_header, "GET");
|
||||||
|
append_allow_header(&mut allow_header, "HEAD");
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::HEAD) {
|
if filter.contains(MethodFilter::HEAD) {
|
||||||
|
append_allow_header(&mut allow_header, "HEAD");
|
||||||
head = svc.clone();
|
head = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::DELETE) {
|
if filter.contains(MethodFilter::DELETE) {
|
||||||
|
append_allow_header(&mut allow_header, "DELETE");
|
||||||
delete = svc.clone();
|
delete = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::OPTIONS) {
|
if filter.contains(MethodFilter::OPTIONS) {
|
||||||
|
append_allow_header(&mut allow_header, "OPTIONS");
|
||||||
options = svc.clone();
|
options = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::PATCH) {
|
if filter.contains(MethodFilter::PATCH) {
|
||||||
|
append_allow_header(&mut allow_header, "PATCH");
|
||||||
patch = svc.clone();
|
patch = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::POST) {
|
if filter.contains(MethodFilter::POST) {
|
||||||
|
append_allow_header(&mut allow_header, "POST");
|
||||||
post = svc.clone();
|
post = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::PUT) {
|
if filter.contains(MethodFilter::PUT) {
|
||||||
|
append_allow_header(&mut allow_header, "PUT");
|
||||||
put = svc.clone();
|
put = svc.clone();
|
||||||
}
|
}
|
||||||
if filter.contains(MethodFilter::TRACE) {
|
if filter.contains(MethodFilter::TRACE) {
|
||||||
|
append_allow_header(&mut allow_header, "TRACE");
|
||||||
trace = svc;
|
trace = svc;
|
||||||
}
|
}
|
||||||
Self {
|
Self {
|
||||||
|
@ -858,9 +900,35 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
|
||||||
put,
|
put,
|
||||||
trace,
|
trace,
|
||||||
fallback,
|
fallback,
|
||||||
|
allow_header,
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn skip_allow_header(mut self) -> Self {
|
||||||
|
self.allow_header = AllowHeader::Skip;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
|
||||||
|
match allow_header {
|
||||||
|
AllowHeader::None => {
|
||||||
|
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
|
||||||
|
}
|
||||||
|
AllowHeader::Skip => {}
|
||||||
|
AllowHeader::Bytes(allow_header) => {
|
||||||
|
if let Ok(s) = std::str::from_utf8(allow_header) {
|
||||||
|
if !s.contains(method) {
|
||||||
|
allow_header.extend_from_slice(b",");
|
||||||
|
allow_header.extend_from_slice(method.as_bytes());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
panic!("`allow_header` contained invalid uft-8. This should never happen")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B, E> Clone for MethodRouter<B, E> {
|
impl<B, E> Clone for MethodRouter<B, E> {
|
||||||
|
@ -875,6 +943,7 @@ impl<B, E> Clone for MethodRouter<B, E> {
|
||||||
put: self.put.clone(),
|
put: self.put.clone(),
|
||||||
trace: self.trace.clone(),
|
trace: self.trace.clone(),
|
||||||
fallback: self.fallback.clone(),
|
fallback: self.fallback.clone(),
|
||||||
|
allow_header: self.allow_header.clone(),
|
||||||
_request_body: PhantomData,
|
_request_body: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -931,6 +1000,7 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
||||||
put,
|
put,
|
||||||
trace,
|
trace,
|
||||||
fallback,
|
fallback,
|
||||||
|
allow_header,
|
||||||
_request_body: _,
|
_request_body: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
|
@ -944,13 +1014,18 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
||||||
call!(req, method, DELETE, delete);
|
call!(req, method, DELETE, delete);
|
||||||
call!(req, method, TRACE, trace);
|
call!(req, method, TRACE, trace);
|
||||||
|
|
||||||
match fallback {
|
let future =
|
||||||
Fallback::Default(fallback) => {
|
match fallback {
|
||||||
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
|
Fallback::Default(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
|
||||||
}
|
.strip_body(method == Method::HEAD),
|
||||||
Fallback::Custom(fallback) => {
|
Fallback::Custom(fallback) => RouteFuture::new(fallback.0.clone().oneshot(req))
|
||||||
RouteFuture::new(fallback.0.clone().oneshot(req)).strip_body(method == Method::HEAD)
|
.strip_body(method == Method::HEAD),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
match allow_header {
|
||||||
|
AllowHeader::None => future.allow_header(Bytes::new()),
|
||||||
|
AllowHeader::Skip => future,
|
||||||
|
AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -959,6 +1034,8 @@ impl<B, E> Service<Request<B>> for MethodRouter<B, E> {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{body::Body, error_handling::HandleErrorLayer};
|
use crate::{body::Body, error_handling::HandleErrorLayer};
|
||||||
|
use axum_core::response::IntoResponse;
|
||||||
|
use http::{header::ALLOW, HeaderMap};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tower::{timeout::TimeoutLayer, Service, ServiceExt};
|
use tower::{timeout::TimeoutLayer, Service, ServiceExt};
|
||||||
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
|
use tower_http::{auth::RequireAuthorizationLayer, services::fs::ServeDir};
|
||||||
|
@ -966,7 +1043,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn method_not_allowed_by_default() {
|
async fn method_not_allowed_by_default() {
|
||||||
let mut svc = MethodRouter::new();
|
let mut svc = MethodRouter::new();
|
||||||
let (status, body) = call(Method::GET, &mut svc).await;
|
let (status, _, body) = call(Method::GET, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
assert!(body.is_empty());
|
assert!(body.is_empty());
|
||||||
}
|
}
|
||||||
|
@ -974,7 +1051,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn get_handler() {
|
async fn get_handler() {
|
||||||
let mut svc = MethodRouter::new().get(ok);
|
let mut svc = MethodRouter::new().get(ok);
|
||||||
let (status, body) = call(Method::GET, &mut svc).await;
|
let (status, _, body) = call(Method::GET, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::OK);
|
assert_eq!(status, StatusCode::OK);
|
||||||
assert_eq!(body, "ok");
|
assert_eq!(body, "ok");
|
||||||
}
|
}
|
||||||
|
@ -982,7 +1059,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn get_accepts_head() {
|
async fn get_accepts_head() {
|
||||||
let mut svc = MethodRouter::new().get(ok);
|
let mut svc = MethodRouter::new().get(ok);
|
||||||
let (status, body) = call(Method::HEAD, &mut svc).await;
|
let (status, _, body) = call(Method::HEAD, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::OK);
|
assert_eq!(status, StatusCode::OK);
|
||||||
assert!(body.is_empty());
|
assert!(body.is_empty());
|
||||||
}
|
}
|
||||||
|
@ -990,7 +1067,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn head_takes_precedence_over_get() {
|
async fn head_takes_precedence_over_get() {
|
||||||
let mut svc = MethodRouter::new().head(created).get(ok);
|
let mut svc = MethodRouter::new().head(created).get(ok);
|
||||||
let (status, body) = call(Method::HEAD, &mut svc).await;
|
let (status, _, body) = call(Method::HEAD, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::CREATED);
|
assert_eq!(status, StatusCode::CREATED);
|
||||||
assert!(body.is_empty());
|
assert!(body.is_empty());
|
||||||
}
|
}
|
||||||
|
@ -999,10 +1076,10 @@ mod tests {
|
||||||
async fn merge() {
|
async fn merge() {
|
||||||
let mut svc = get(ok).merge(post(ok));
|
let mut svc = get(ok).merge(post(ok));
|
||||||
|
|
||||||
let (status, _) = call(Method::GET, &mut svc).await;
|
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::OK);
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
let (status, _) = call(Method::POST, &mut svc).await;
|
let (status, _, _) = call(Method::POST, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::OK);
|
assert_eq!(status, StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1013,11 +1090,11 @@ mod tests {
|
||||||
.layer(RequireAuthorizationLayer::bearer("password"));
|
.layer(RequireAuthorizationLayer::bearer("password"));
|
||||||
|
|
||||||
// method with route
|
// method with route
|
||||||
let (status, _) = call(Method::GET, &mut svc).await;
|
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
// method without route
|
// method without route
|
||||||
let (status, _) = call(Method::DELETE, &mut svc).await;
|
let (status, _, _) = call(Method::DELETE, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1028,11 +1105,11 @@ mod tests {
|
||||||
.route_layer(RequireAuthorizationLayer::bearer("password"));
|
.route_layer(RequireAuthorizationLayer::bearer("password"));
|
||||||
|
|
||||||
// method with route
|
// method with route
|
||||||
let (status, _) = call(Method::GET, &mut svc).await;
|
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
// method without route
|
// method without route
|
||||||
let (status, _) = call(Method::DELETE, &mut svc).await;
|
let (status, _, _) = call(Method::DELETE, &mut svc).await;
|
||||||
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1062,7 +1139,93 @@ mod tests {
|
||||||
crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
|
crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, String)
|
#[tokio::test]
|
||||||
|
async fn sets_allow_header() {
|
||||||
|
let mut svc = MethodRouter::new().put(ok).patch(ok);
|
||||||
|
let (status, headers, _) = call(Method::GET, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "PUT,PATCH");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sets_allow_header_get_head() {
|
||||||
|
let mut svc = MethodRouter::new().get(ok).head(ok);
|
||||||
|
let (status, headers, _) = call(Method::PUT, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "GET,HEAD");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn empty_allow_header_by_default() {
|
||||||
|
let mut svc = MethodRouter::new();
|
||||||
|
let (status, headers, _) = call(Method::PATCH, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn allow_header_when_merging() {
|
||||||
|
let a = put(ok).patch(ok);
|
||||||
|
let b = get(ok).head(ok);
|
||||||
|
let mut svc = a.merge(b);
|
||||||
|
|
||||||
|
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn allow_header_any() {
|
||||||
|
let mut svc = any(ok);
|
||||||
|
|
||||||
|
let (status, headers, _) = call(Method::GET, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(!headers.contains_key(ALLOW));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn allow_header_with_fallback() {
|
||||||
|
let mut svc = MethodRouter::new().get(ok).fallback(
|
||||||
|
(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "GET,HEAD");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn allow_header_with_fallback_that_sets_allow() {
|
||||||
|
async fn fallback(method: Method) -> Response {
|
||||||
|
if method == Method::POST {
|
||||||
|
"OK".into_response()
|
||||||
|
} else {
|
||||||
|
let headers = crate::response::Headers([(ALLOW, "GET,POST")]);
|
||||||
|
(
|
||||||
|
StatusCode::METHOD_NOT_ALLOWED,
|
||||||
|
headers,
|
||||||
|
"Method not allowed",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut svc = MethodRouter::new()
|
||||||
|
.get(ok)
|
||||||
|
.fallback(fallback.into_service());
|
||||||
|
|
||||||
|
let (status, _, _) = call(Method::GET, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
let (status, _, _) = call(Method::POST, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
let (status, headers, _) = call(Method::DELETE, &mut svc).await;
|
||||||
|
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
|
||||||
|
assert_eq!(headers[ALLOW], "GET,POST");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
|
||||||
where
|
where
|
||||||
S: Service<Request<Body>, Response = Response, Error = Infallible>,
|
S: Service<Request<Body>, Response = Response, Error = Infallible>,
|
||||||
{
|
{
|
||||||
|
@ -1074,7 +1237,7 @@ mod tests {
|
||||||
let response = svc.ready().await.unwrap().call(request).await.unwrap();
|
let response = svc.ready().await.unwrap().call(request).await.unwrap();
|
||||||
let (parts, body) = response.into_parts();
|
let (parts, body) = response.into_parts();
|
||||||
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
|
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
|
||||||
(parts.status, body)
|
(parts.status, parts.headers, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn ok() -> (StatusCode, &'static str) {
|
async fn ok() -> (StatusCode, &'static str) {
|
||||||
|
|
|
@ -2,7 +2,8 @@ use crate::{
|
||||||
body::{boxed, Body, Empty},
|
body::{boxed, Body, Empty},
|
||||||
response::Response,
|
response::Response,
|
||||||
};
|
};
|
||||||
use http::Request;
|
use bytes::Bytes;
|
||||||
|
use http::{header, HeaderValue, Request};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
|
@ -70,6 +71,7 @@ pin_project! {
|
||||||
Request<B>,
|
Request<B>,
|
||||||
>,
|
>,
|
||||||
strip_body: bool,
|
strip_body: bool,
|
||||||
|
allow_header: Option<Bytes>,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,6 +82,7 @@ impl<B, E> RouteFuture<B, E> {
|
||||||
RouteFuture {
|
RouteFuture {
|
||||||
future,
|
future,
|
||||||
strip_body: false,
|
strip_body: false,
|
||||||
|
allow_header: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,22 +90,41 @@ impl<B, E> RouteFuture<B, E> {
|
||||||
self.strip_body = strip_body;
|
self.strip_body = strip_body;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
|
||||||
|
self.allow_header = Some(allow_header);
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B, E> Future for RouteFuture<B, E> {
|
impl<B, E> Future for RouteFuture<B, E> {
|
||||||
type Output = Result<Response, E>;
|
type Output = Result<Response, E>;
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let strip_body = self.strip_body;
|
let strip_body = self.strip_body;
|
||||||
|
let allow_header = self.allow_header.take();
|
||||||
|
|
||||||
match self.project().future.poll(cx) {
|
match self.project().future.poll(cx) {
|
||||||
Poll::Ready(Ok(res)) => {
|
Poll::Ready(Ok(res)) => {
|
||||||
if strip_body {
|
let mut res = if strip_body {
|
||||||
Poll::Ready(Ok(res.map(|_| boxed(Empty::new()))))
|
res.map(|_| boxed(Empty::new()))
|
||||||
} else {
|
} else {
|
||||||
Poll::Ready(Ok(res))
|
res
|
||||||
|
};
|
||||||
|
|
||||||
|
match allow_header {
|
||||||
|
Some(allow_header) if !res.headers().contains_key(header::ALLOW) => {
|
||||||
|
res.headers_mut().insert(
|
||||||
|
header::ALLOW,
|
||||||
|
HeaderValue::from_maybe_shared(allow_header)
|
||||||
|
.expect("invalid `Allow` header"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(res))
|
||||||
}
|
}
|
||||||
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
|
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
|
||||||
Poll::Pending => Poll::Pending,
|
Poll::Pending => Poll::Pending,
|
||||||
|
|
Loading…
Reference in a new issue