1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-03-30 19:28:16 +02:00

Return 405 Method Not Allowed for unsupported method for route ()

Fixes https://github.com/tokio-rs/axum/issues/61
This commit is contained in:
David Pedersen 2021-07-31 21:05:53 +02:00 committed by GitHub
parent 49dd1ca49a
commit 407aa533d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 151 additions and 14 deletions

View file

@ -7,9 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- Implement `Stream` for `WebSocket`.
- Implement `Sink` for `WebSocket`.
- Implement `Deref` most extractors.
- Implement `Stream` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52))
- Implement `Sink` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52))
- Implement `Deref` most extractors ([#56](https://github.com/tokio-rs/axum/pull/56))
- Return `405 Method Not Allowed` for unsupported method for route ([#63](https://github.com/tokio-rs/axum/pull/63))
## Breaking changes

View file

@ -168,7 +168,7 @@ where
OnMethod {
method,
svc: handler.into_service(),
fallback: EmptyRouter::new(),
fallback: EmptyRouter::method_not_allowed(),
}
}

View file

@ -676,7 +676,7 @@ where
{
use routing::RoutingDsl;
routing::EmptyRouter::new().route(description, service)
routing::EmptyRouter::not_found().route(description, service)
}
mod sealed {

View file

@ -84,7 +84,6 @@ pub struct Route<S, F> {
}
/// Trait for building routers.
// TODO(david): this name isn't great
#[async_trait]
pub trait RoutingDsl: crate::sealed::Sealed + Sized {
/// Add another route to the router.
@ -364,21 +363,38 @@ fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
}
}
/// A [`Service`] that responds with `404 Not Found` to all requests.
/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed`
/// to all requests.
///
/// This is used as the bottom service in a router stack. You shouldn't have to
/// use to manually.
pub struct EmptyRouter<E = Infallible>(PhantomData<fn() -> E>);
pub struct EmptyRouter<E = Infallible> {
status: StatusCode,
_marker: PhantomData<fn() -> E>,
}
impl<E> EmptyRouter<E> {
pub(crate) fn new() -> Self {
Self(PhantomData)
pub(crate) fn not_found() -> Self {
Self {
status: StatusCode::NOT_FOUND,
_marker: PhantomData,
}
}
pub(crate) fn method_not_allowed() -> Self {
Self {
status: StatusCode::METHOD_NOT_ALLOWED,
_marker: PhantomData,
}
}
}
impl<E> Clone for EmptyRouter<E> {
fn clone(&self) -> Self {
Self(PhantomData)
Self {
status: self.status,
_marker: PhantomData,
}
}
}
@ -405,7 +421,7 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> {
fn call(&mut self, _req: Request<B>) -> Self::Future {
let mut res = Response::new(crate::body::empty());
*res.status_mut() = StatusCode::NOT_FOUND;
*res.status_mut() = self.status;
EmptyRouterFuture(future::ok(res))
}
}
@ -806,7 +822,7 @@ where
Nested {
pattern: PathPattern::new(description),
svc,
fallback: EmptyRouter::new(),
fallback: EmptyRouter::not_found(),
}
}

View file

@ -256,7 +256,7 @@ where
inner: svc,
_request_body: PhantomData,
},
fallback: EmptyRouter::new(),
fallback: EmptyRouter::method_not_allowed(),
}
}

View file

@ -3,6 +3,7 @@ use crate::{
service,
};
use bytes::Bytes;
use futures_util::future::Ready;
use http::{header::AUTHORIZATION, Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
@ -10,6 +11,7 @@ use serde_json::json;
use std::{
convert::Infallible,
net::{SocketAddr, TcpListener},
task::{Context, Poll},
time::Duration,
};
use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder};
@ -677,6 +679,124 @@ async fn test_extractor_middleware() {
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn wrong_method_handler() {
let app = route("/", get(|| async {}).post(|| async {})).route("/foo", patch(|| async {}));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.patch(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.patch(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.get(format!("http://{}/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn wrong_method_nest() {
let nested_app = route("/", get(|| async {}));
let app = crate::routing::nest("/", nested_app);
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.patch(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn wrong_method_service() {
#[derive(Clone)]
struct Svc;
impl<R> Service<R> for Svc {
type Response = Response<http_body::Empty<Bytes>>;
type Error = Infallible;
type Future = Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: R) -> Self::Future {
futures_util::future::ok(Response::new(http_body::Empty::new()))
}
}
let app = route("/", service::get(Svc).post(Svc)).route("/foo", service::patch(Svc));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.patch(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.patch(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
let res = client
.get(format!("http://{}/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
/// Run a `tower::Service` in the background and get a URI for it.
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where