mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Fallback to calling next route if no methods match (#224)
This removes a small foot gun from the routing. This means matching different HTTP methods for the same route that aren't defined together now works. So `Router::new().route("/", get(...)).route("/", post(...))` now accepts both `GET` and `POST`. Previously only `POST` would be accepted.
This commit is contained in:
parent
971c0a394a
commit
0d8f8b7b6c
6 changed files with 112 additions and 94 deletions
|
@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193))
|
||||
- Add `Redirect` response ([#192](https://github.com/tokio-rs/axum/pull/192))
|
||||
- Make `RequestParts::{new, try_into_request}` public ([#194](https://github.com/tokio-rs/axum/pull/194))
|
||||
- Support matching different HTTP methods for the same route that aren't defined
|
||||
together. So `Router::new().route("/", get(...)).route("/", post(...))` now
|
||||
accepts both `GET` and `POST`. Previously only `POST` would be accepted ([#224](https://github.com/tokio-rs/axum/pull/224))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
24
src/lib.rs
24
src/lib.rs
|
@ -203,30 +203,6 @@
|
|||
//! # }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Matching multiple methods
|
||||
//!
|
||||
//! If you want a path to accept multiple HTTP methods you must add them all at
|
||||
//! once:
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use axum::{
|
||||
//! Router,
|
||||
//! handler::{get, post},
|
||||
//! };
|
||||
//!
|
||||
//! // `GET /` and `POST /` are both accepted
|
||||
//! let app = Router::new().route("/", get(handler).post(handler));
|
||||
//!
|
||||
//! // This will _not_ work. Only `POST /` will be accessible.
|
||||
//! let wont_work = Router::new().route("/", get(handler)).route("/", post(handler));
|
||||
//!
|
||||
//! async fn handler() {}
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(wont_work.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! ## Routing to any [`Service`]
|
||||
//!
|
||||
//! axum also supports routing to general [`Service`]s:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
//! Future types.
|
||||
|
||||
use crate::{body::BoxBody, buffer::MpscBuffer};
|
||||
use crate::{body::BoxBody, buffer::MpscBuffer, routing::FromEmptyRouter};
|
||||
use futures_util::ready;
|
||||
use http::{Request, Response};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
|
@ -11,7 +12,7 @@ use std::{
|
|||
};
|
||||
use tower::{
|
||||
util::{BoxService, Oneshot},
|
||||
BoxError, Service,
|
||||
BoxError, Service, ServiceExt,
|
||||
};
|
||||
|
||||
pub use super::or::ResponseFuture as OrResponseFuture;
|
||||
|
@ -68,7 +69,7 @@ pin_project! {
|
|||
F: Service<Request<B>>
|
||||
{
|
||||
#[pin]
|
||||
inner: RouteFutureInner<S, F, B>,
|
||||
state: RouteFutureInner<S, F, B>,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,15 +78,18 @@ where
|
|||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>>,
|
||||
{
|
||||
pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
|
||||
pub(crate) fn a(a: Oneshot<S, Request<B>>, fallback: F) -> Self {
|
||||
RouteFuture {
|
||||
inner: RouteFutureInner::A { a },
|
||||
state: RouteFutureInner::A {
|
||||
a,
|
||||
fallback: Some(fallback),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self {
|
||||
RouteFuture {
|
||||
inner: RouteFutureInner::B { b },
|
||||
state: RouteFutureInner::B { b },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -98,8 +102,15 @@ pin_project! {
|
|||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>>,
|
||||
{
|
||||
A { #[pin] a: Oneshot<S, Request<B>> },
|
||||
B { #[pin] b: Oneshot<F, Request<B>> },
|
||||
A {
|
||||
#[pin]
|
||||
a: Oneshot<S, Request<B>>,
|
||||
fallback: Option<F>,
|
||||
},
|
||||
B {
|
||||
#[pin]
|
||||
b: Oneshot<F, Request<B>>
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,13 +118,38 @@ impl<S, F, B> Future for RouteFuture<S, F, B>
|
|||
where
|
||||
S: Service<Request<B>, Response = Response<BoxBody>>,
|
||||
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Output = Result<Response<BoxBody>, S::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project().inner.project() {
|
||||
RouteFutureInnerProj::A { a } => a.poll(cx),
|
||||
RouteFutureInnerProj::B { b } => b.poll(cx),
|
||||
#[allow(warnings)]
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
loop {
|
||||
let mut this = self.as_mut().project();
|
||||
|
||||
let new_state = match this.state.as_mut().project() {
|
||||
RouteFutureInnerProj::A { a, fallback } => {
|
||||
let mut response = ready!(a.poll(cx))?;
|
||||
|
||||
let req = if let Some(ext) =
|
||||
response.extensions_mut().remove::<FromEmptyRouter<B>>()
|
||||
{
|
||||
ext.request
|
||||
} else {
|
||||
return Poll::Ready(Ok(response));
|
||||
};
|
||||
|
||||
RouteFutureInner::B {
|
||||
b: fallback
|
||||
.take()
|
||||
.expect("future polled after completion")
|
||||
.oneshot(req),
|
||||
}
|
||||
}
|
||||
RouteFutureInnerProj::B { b } => return b.poll(cx),
|
||||
};
|
||||
|
||||
this.state.set(new_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -135,6 +171,7 @@ impl<S, F, B> Future for NestedFuture<S, F, B>
|
|||
where
|
||||
S: Service<Request<B>, Response = Response<BoxBody>>,
|
||||
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Output = Result<Response<BoxBody>, S::Error>;
|
||||
|
||||
|
|
|
@ -597,6 +597,7 @@ impl<S, F, B> Service<Request<B>> for Route<S, F>
|
|||
where
|
||||
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
|
||||
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = S::Error;
|
||||
|
@ -610,7 +611,7 @@ where
|
|||
if let Some(captures) = self.pattern.full_match(&req) {
|
||||
insert_url_params(&mut req, captures);
|
||||
let fut = self.svc.clone().oneshot(req);
|
||||
RouteFuture::a(fut)
|
||||
RouteFuture::a(fut, self.fallback.clone())
|
||||
} else {
|
||||
let fut = self.fallback.clone().oneshot(req);
|
||||
RouteFuture::b(fut)
|
||||
|
@ -689,12 +690,23 @@ where
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, request: Request<B>) -> Self::Future {
|
||||
fn call(&mut self, mut request: Request<B>) -> Self::Future {
|
||||
if self.status == StatusCode::METHOD_NOT_ALLOWED {
|
||||
// we're inside a route but there was no method that matched
|
||||
// so record that so we can override the status if no other
|
||||
// routes match
|
||||
request.extensions_mut().insert(NoMethodMatch);
|
||||
}
|
||||
|
||||
if self.status == StatusCode::NOT_FOUND
|
||||
&& request.extensions().get::<NoMethodMatch>().is_some()
|
||||
{
|
||||
self.status = StatusCode::METHOD_NOT_ALLOWED;
|
||||
}
|
||||
|
||||
let mut res = Response::new(crate::body::empty());
|
||||
|
||||
if request.extensions().get::<OrDepth>().is_some() {
|
||||
res.extensions_mut().insert(FromEmptyRouter { request });
|
||||
}
|
||||
res.extensions_mut().insert(FromEmptyRouter { request });
|
||||
|
||||
*res.status_mut() = self.status;
|
||||
EmptyRouterFuture {
|
||||
|
@ -703,6 +715,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NoMethodMatch;
|
||||
|
||||
/// Response extension used by [`EmptyRouter`] to send the request back to [`Or`] so
|
||||
/// the other service can be called.
|
||||
///
|
||||
|
@ -713,39 +728,6 @@ struct FromEmptyRouter<B> {
|
|||
request: Request<B>,
|
||||
}
|
||||
|
||||
/// We need to track whether we're inside an `Or` or not, and only if we then
|
||||
/// should we save the request into the response extensions.
|
||||
///
|
||||
/// This is to work around https://github.com/hyperium/hyper/issues/2621.
|
||||
///
|
||||
/// Since ours can be nested we have to track the depth to know when we're
|
||||
/// leaving the top most `Or`.
|
||||
///
|
||||
/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we
|
||||
/// can remove this nasty hack.
|
||||
#[derive(Debug)]
|
||||
struct OrDepth(usize);
|
||||
|
||||
impl OrDepth {
|
||||
fn new() -> Self {
|
||||
Self(1)
|
||||
}
|
||||
|
||||
fn increment(&mut self) {
|
||||
self.0 += 1;
|
||||
}
|
||||
|
||||
fn decrement(&mut self) {
|
||||
self.0 -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<usize> for &mut OrDepth {
|
||||
fn eq(&self, other: &usize) -> bool {
|
||||
self.0 == *other
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct PathPattern(Arc<Inner>);
|
||||
|
||||
|
@ -945,6 +927,7 @@ impl<S, F, B> Service<Request<B>> for Nested<S, F>
|
|||
where
|
||||
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
|
||||
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = S::Error;
|
||||
|
@ -966,7 +949,7 @@ where
|
|||
|
||||
insert_url_params(&mut req, captures);
|
||||
let fut = self.svc.clone().oneshot(req);
|
||||
RouteFuture::a(fut)
|
||||
RouteFuture::a(fut, self.fallback.clone())
|
||||
} else {
|
||||
let fut = self.fallback.clone().oneshot(req);
|
||||
RouteFuture::b(fut)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! [`Or`] used to combine two services into one.
|
||||
|
||||
use super::{FromEmptyRouter, OrDepth};
|
||||
use super::FromEmptyRouter;
|
||||
use crate::body::BoxBody;
|
||||
use futures_util::ready;
|
||||
use http::{Request, Response};
|
||||
|
@ -45,12 +45,6 @@ where
|
|||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
let original_uri = req.uri().clone();
|
||||
|
||||
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
count.increment();
|
||||
} else {
|
||||
req.extensions_mut().insert(OrDepth::new());
|
||||
}
|
||||
|
||||
ResponseFuture {
|
||||
state: State::FirstFuture {
|
||||
f: self.first.clone().oneshot(req),
|
||||
|
@ -119,18 +113,6 @@ where
|
|||
|
||||
*req.uri_mut() = this.original_uri.take().unwrap();
|
||||
|
||||
let mut leaving_outermost_or = false;
|
||||
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||
if depth == 1 {
|
||||
leaving_outermost_or = true;
|
||||
} else {
|
||||
depth.decrement();
|
||||
}
|
||||
}
|
||||
if leaving_outermost_or {
|
||||
req.extensions_mut().remove::<OrDepth>();
|
||||
}
|
||||
|
||||
let second = this.second.take().expect("future polled after completion");
|
||||
|
||||
State::SecondFuture {
|
||||
|
|
|
@ -632,6 +632,43 @@ async fn handler_into_service() {
|
|||
assert_eq!(res.text().await.unwrap(), "you said: hi there!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn when_multiple_routes_match() {
|
||||
let app = Router::new()
|
||||
.route("/", post(|| async {}))
|
||||
.route("/", get(|| async {}))
|
||||
.route("/foo", get(|| async {}))
|
||||
.nest("/foo", Router::new().route("/bar", get(|| async {})));
|
||||
|
||||
let addr = run_in_background(app).await;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.post(format!("http://{}", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo/bar", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let res = client
|
||||
.get(format!("http://{}/foo", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
/// Run a `tower::Service` in the background and get a URI for it.
|
||||
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
||||
where
|
||||
|
|
Loading…
Add table
Reference in a new issue