Fallback to calling next route if no methods match (#224)

This removes a small foot gun from the routing.

This means matching different HTTP methods for the same route that
aren't defined together now works.

So `Router::new().route("/", get(...)).route("/", post(...))` now
accepts both `GET` and `POST`. Previously only `POST` would be accepted.
This commit is contained in:
David Pedersen 2021-08-21 01:00:12 +02:00 committed by GitHub
parent 971c0a394a
commit 0d8f8b7b6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 112 additions and 94 deletions

View file

@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `Headers` for easily customizing headers on a response ([#193](https://github.com/tokio-rs/axum/pull/193))
- Add `Redirect` response ([#192](https://github.com/tokio-rs/axum/pull/192))
- Make `RequestParts::{new, try_into_request}` public ([#194](https://github.com/tokio-rs/axum/pull/194))
- Support matching different HTTP methods for the same route that aren't defined
together. So `Router::new().route("/", get(...)).route("/", post(...))` now
accepts both `GET` and `POST`. Previously only `POST` would be accepted ([#224](https://github.com/tokio-rs/axum/pull/224))
## Breaking changes

View file

@ -203,30 +203,6 @@
//! # }
//! ```
//!
//! ## Matching multiple methods
//!
//! If you want a path to accept multiple HTTP methods you must add them all at
//! once:
//!
//! ```rust,no_run
//! use axum::{
//! Router,
//! handler::{get, post},
//! };
//!
//! // `GET /` and `POST /` are both accepted
//! let app = Router::new().route("/", get(handler).post(handler));
//!
//! // This will _not_ work. Only `POST /` will be accessible.
//! let wont_work = Router::new().route("/", get(handler)).route("/", post(handler));
//!
//! async fn handler() {}
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # axum::Server::bind(&"".parse().unwrap()).serve(wont_work.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! ## Routing to any [`Service`]
//!
//! axum also supports routing to general [`Service`]s:

View file

@ -1,6 +1,7 @@
//! Future types.
use crate::{body::BoxBody, buffer::MpscBuffer};
use crate::{body::BoxBody, buffer::MpscBuffer, routing::FromEmptyRouter};
use futures_util::ready;
use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
@ -11,7 +12,7 @@ use std::{
};
use tower::{
util::{BoxService, Oneshot},
BoxError, Service,
BoxError, Service, ServiceExt,
};
pub use super::or::ResponseFuture as OrResponseFuture;
@ -68,7 +69,7 @@ pin_project! {
F: Service<Request<B>>
{
#[pin]
inner: RouteFutureInner<S, F, B>,
state: RouteFutureInner<S, F, B>,
}
}
@ -77,15 +78,18 @@ where
S: Service<Request<B>>,
F: Service<Request<B>>,
{
pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
pub(crate) fn a(a: Oneshot<S, Request<B>>, fallback: F) -> Self {
RouteFuture {
inner: RouteFutureInner::A { a },
state: RouteFutureInner::A {
a,
fallback: Some(fallback),
},
}
}
pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self {
RouteFuture {
inner: RouteFutureInner::B { b },
state: RouteFutureInner::B { b },
}
}
}
@ -98,8 +102,15 @@ pin_project! {
S: Service<Request<B>>,
F: Service<Request<B>>,
{
A { #[pin] a: Oneshot<S, Request<B>> },
B { #[pin] b: Oneshot<F, Request<B>> },
A {
#[pin]
a: Oneshot<S, Request<B>>,
fallback: Option<F>,
},
B {
#[pin]
b: Oneshot<F, Request<B>>
},
}
}
@ -107,13 +118,38 @@ impl<S, F, B> Future for RouteFuture<S, F, B>
where
S: Service<Request<B>, Response = Response<BoxBody>>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
B: Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
RouteFutureInnerProj::A { a } => a.poll(cx),
RouteFutureInnerProj::B { b } => b.poll(cx),
#[allow(warnings)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let mut this = self.as_mut().project();
let new_state = match this.state.as_mut().project() {
RouteFutureInnerProj::A { a, fallback } => {
let mut response = ready!(a.poll(cx))?;
let req = if let Some(ext) =
response.extensions_mut().remove::<FromEmptyRouter<B>>()
{
ext.request
} else {
return Poll::Ready(Ok(response));
};
RouteFutureInner::B {
b: fallback
.take()
.expect("future polled after completion")
.oneshot(req),
}
}
RouteFutureInnerProj::B { b } => return b.poll(cx),
};
this.state.set(new_state);
}
}
}
@ -135,6 +171,7 @@ impl<S, F, B> Future for NestedFuture<S, F, B>
where
S: Service<Request<B>, Response = Response<BoxBody>>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
B: Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, S::Error>;

View file

@ -597,6 +597,7 @@ impl<S, F, B> Service<Request<B>> for Route<S, F>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
@ -610,7 +611,7 @@ where
if let Some(captures) = self.pattern.full_match(&req) {
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
RouteFuture::a(fut, self.fallback.clone())
} else {
let fut = self.fallback.clone().oneshot(req);
RouteFuture::b(fut)
@ -689,12 +690,23 @@ where
Poll::Ready(Ok(()))
}
fn call(&mut self, request: Request<B>) -> Self::Future {
fn call(&mut self, mut request: Request<B>) -> Self::Future {
if self.status == StatusCode::METHOD_NOT_ALLOWED {
// we're inside a route but there was no method that matched
// so record that so we can override the status if no other
// routes match
request.extensions_mut().insert(NoMethodMatch);
}
if self.status == StatusCode::NOT_FOUND
&& request.extensions().get::<NoMethodMatch>().is_some()
{
self.status = StatusCode::METHOD_NOT_ALLOWED;
}
let mut res = Response::new(crate::body::empty());
if request.extensions().get::<OrDepth>().is_some() {
res.extensions_mut().insert(FromEmptyRouter { request });
}
res.extensions_mut().insert(FromEmptyRouter { request });
*res.status_mut() = self.status;
EmptyRouterFuture {
@ -703,6 +715,9 @@ where
}
}
#[derive(Clone, Copy)]
struct NoMethodMatch;
/// Response extension used by [`EmptyRouter`] to send the request back to [`Or`] so
/// the other service can be called.
///
@ -713,39 +728,6 @@ struct FromEmptyRouter<B> {
request: Request<B>,
}
/// We need to track whether we're inside an `Or` or not, and only if we then
/// should we save the request into the response extensions.
///
/// This is to work around https://github.com/hyperium/hyper/issues/2621.
///
/// Since ours can be nested we have to track the depth to know when we're
/// leaving the top most `Or`.
///
/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we
/// can remove this nasty hack.
#[derive(Debug)]
struct OrDepth(usize);
impl OrDepth {
fn new() -> Self {
Self(1)
}
fn increment(&mut self) {
self.0 += 1;
}
fn decrement(&mut self) {
self.0 -= 1;
}
}
impl PartialEq<usize> for &mut OrDepth {
fn eq(&self, other: &usize) -> bool {
self.0 == *other
}
}
#[derive(Debug, Clone)]
pub(crate) struct PathPattern(Arc<Inner>);
@ -945,6 +927,7 @@ impl<S, F, B> Service<Request<B>> for Nested<S, F>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
@ -966,7 +949,7 @@ where
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut)
RouteFuture::a(fut, self.fallback.clone())
} else {
let fut = self.fallback.clone().oneshot(req);
RouteFuture::b(fut)

View file

@ -1,6 +1,6 @@
//! [`Or`] used to combine two services into one.
use super::{FromEmptyRouter, OrDepth};
use super::FromEmptyRouter;
use crate::body::BoxBody;
use futures_util::ready;
use http::{Request, Response};
@ -45,12 +45,6 @@ where
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
let original_uri = req.uri().clone();
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
count.increment();
} else {
req.extensions_mut().insert(OrDepth::new());
}
ResponseFuture {
state: State::FirstFuture {
f: self.first.clone().oneshot(req),
@ -119,18 +113,6 @@ where
*req.uri_mut() = this.original_uri.take().unwrap();
let mut leaving_outermost_or = false;
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
if depth == 1 {
leaving_outermost_or = true;
} else {
depth.decrement();
}
}
if leaving_outermost_or {
req.extensions_mut().remove::<OrDepth>();
}
let second = this.second.take().expect("future polled after completion");
State::SecondFuture {

View file

@ -632,6 +632,43 @@ async fn handler_into_service() {
assert_eq!(res.text().await.unwrap(), "you said: hi there!");
}
#[tokio::test]
async fn when_multiple_routes_match() {
let app = Router::new()
.route("/", post(|| async {}))
.route("/", get(|| async {}))
.route("/foo", get(|| async {}))
.nest("/foo", Router::new().route("/bar", get(|| async {})));
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.post(format!("http://{}", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.get(format!("http://{}/foo/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
/// Run a `tower::Service` in the background and get a URI for it.
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where