mirror of
https://github.com/tokio-rs/axum.git
synced 2024-10-24 01:46:51 +02:00
Work around for http2 hang with Or
(#199)
This is a nasty hack that works around https://github.com/hyperium/hyper/issues/2621. Fixes https://github.com/tokio-rs/axum/issues/191
This commit is contained in:
parent
97c140cdf7
commit
93cdfe8c5f
4 changed files with 175 additions and 5 deletions
|
@ -16,7 +16,7 @@ async fn main() {
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
// build our application with a route
|
// build our application with a route
|
||||||
let app = route("/", get(handler));
|
let app = route("/foo", get(handler)).or(route("/bar", get(handler)));
|
||||||
|
|
||||||
// run it
|
// run it
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||||
|
|
|
@ -547,7 +547,11 @@ where
|
||||||
|
|
||||||
fn call(&mut self, request: Request<B>) -> Self::Future {
|
fn call(&mut self, request: Request<B>) -> Self::Future {
|
||||||
let mut res = Response::new(crate::body::empty());
|
let mut res = Response::new(crate::body::empty());
|
||||||
res.extensions_mut().insert(FromEmptyRouter { request });
|
|
||||||
|
if request.extensions().get::<OrDepth>().is_some() {
|
||||||
|
res.extensions_mut().insert(FromEmptyRouter { request });
|
||||||
|
}
|
||||||
|
|
||||||
*res.status_mut() = self.status;
|
*res.status_mut() = self.status;
|
||||||
EmptyRouterFuture {
|
EmptyRouterFuture {
|
||||||
future: futures_util::future::ok(res),
|
future: futures_util::future::ok(res),
|
||||||
|
@ -565,6 +569,39 @@ struct FromEmptyRouter<B> {
|
||||||
request: Request<B>,
|
request: Request<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// We need to track whether we're inside an `Or` or not, and only if we then
|
||||||
|
/// should we save the request into the response extensions.
|
||||||
|
///
|
||||||
|
/// This is to work around https://github.com/hyperium/hyper/issues/2621.
|
||||||
|
///
|
||||||
|
/// Since ours can be nested we have to track the depth to know when we're
|
||||||
|
/// leaving the top most `Or`.
|
||||||
|
///
|
||||||
|
/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we
|
||||||
|
/// can remove this nasty hack.
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct OrDepth(usize);
|
||||||
|
|
||||||
|
impl OrDepth {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn increment(&mut self) {
|
||||||
|
self.0 += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrement(&mut self) {
|
||||||
|
self.0 -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<usize> for &mut OrDepth {
|
||||||
|
fn eq(&self, other: &usize) -> bool {
|
||||||
|
self.0 == *other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(crate) struct PathPattern(Arc<Inner>);
|
pub(crate) struct PathPattern(Arc<Inner>);
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
//! [`Or`] used to combine two services into one.
|
//! [`Or`] used to combine two services into one.
|
||||||
|
|
||||||
use super::{FromEmptyRouter, RoutingDsl};
|
use super::{FromEmptyRouter, OrDepth, RoutingDsl};
|
||||||
use crate::body::BoxBody;
|
use crate::body::BoxBody;
|
||||||
use futures_util::ready;
|
use futures_util::ready;
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
|
@ -46,7 +46,13 @@ where
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||||
|
if let Some(count) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||||
|
count.increment();
|
||||||
|
} else {
|
||||||
|
req.extensions_mut().insert(OrDepth::new());
|
||||||
|
}
|
||||||
|
|
||||||
ResponseFuture {
|
ResponseFuture {
|
||||||
state: State::FirstFuture {
|
state: State::FirstFuture {
|
||||||
f: self.first.clone().oneshot(req),
|
f: self.first.clone().oneshot(req),
|
||||||
|
@ -100,7 +106,7 @@ where
|
||||||
StateProj::FirstFuture { f } => {
|
StateProj::FirstFuture { f } => {
|
||||||
let mut response = ready!(f.poll(cx)?);
|
let mut response = ready!(f.poll(cx)?);
|
||||||
|
|
||||||
let req = if let Some(ext) = response
|
let mut req = if let Some(ext) = response
|
||||||
.extensions_mut()
|
.extensions_mut()
|
||||||
.remove::<FromEmptyRouter<ReqBody>>()
|
.remove::<FromEmptyRouter<ReqBody>>()
|
||||||
{
|
{
|
||||||
|
@ -109,6 +115,18 @@ where
|
||||||
return Poll::Ready(Ok(response));
|
return Poll::Ready(Ok(response));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut leaving_outermost_or = false;
|
||||||
|
if let Some(depth) = req.extensions_mut().get_mut::<OrDepth>() {
|
||||||
|
if depth == 1 {
|
||||||
|
leaving_outermost_or = true;
|
||||||
|
} else {
|
||||||
|
depth.decrement();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if leaving_outermost_or {
|
||||||
|
req.extensions_mut().remove::<OrDepth>();
|
||||||
|
}
|
||||||
|
|
||||||
let second = this.second.take().expect("future polled after completion");
|
let second = this.second.take().expect("future polled after completion");
|
||||||
|
|
||||||
State::SecondFuture {
|
State::SecondFuture {
|
||||||
|
|
115
src/tests/or.rs
115
src/tests/or.rs
|
@ -41,6 +41,121 @@ async fn basic() {
|
||||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn multiple_ors_balanced_differently() {
|
||||||
|
let one = route("/one", get(|| async { "one" }));
|
||||||
|
let two = route("/two", get(|| async { "two" }));
|
||||||
|
let three = route("/three", get(|| async { "three" }));
|
||||||
|
let four = route("/four", get(|| async { "four" }));
|
||||||
|
|
||||||
|
test(
|
||||||
|
"one",
|
||||||
|
one.clone()
|
||||||
|
.or(two.clone())
|
||||||
|
.or(three.clone())
|
||||||
|
.or(four.clone()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
test(
|
||||||
|
"two",
|
||||||
|
one.clone()
|
||||||
|
.or(two.clone())
|
||||||
|
.or(three.clone().or(four.clone())),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
test(
|
||||||
|
"three",
|
||||||
|
one.clone()
|
||||||
|
.or(two.clone().or(three.clone()).or(four.clone())),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
test("four", one.or(two.or(three.or(four)))).await;
|
||||||
|
|
||||||
|
async fn test<S, ResBody>(name: &str, app: S)
|
||||||
|
where
|
||||||
|
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||||
|
ResBody: http_body::Body + Send + 'static,
|
||||||
|
ResBody::Data: Send,
|
||||||
|
ResBody::Error: Into<BoxError>,
|
||||||
|
S::Future: Send,
|
||||||
|
S::Error: Into<BoxError>,
|
||||||
|
{
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
for n in ["one", "two", "three", "four"].iter() {
|
||||||
|
println!("running: {} / {}", name, n);
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/{}", addr, n))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
assert_eq!(res.text().await.unwrap(), *n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn or_nested_inside_other_thing() {
|
||||||
|
let inner = route("/bar", get(|| async {})).or(route("/baz", get(|| async {})));
|
||||||
|
let app = nest("/foo", inner);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/foo/bar", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/foo/baz", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn or_with_route_following() {
|
||||||
|
let one = route("/one", get(|| async { "one" }));
|
||||||
|
let two = route("/two", get(|| async { "two" }));
|
||||||
|
let app = one.or(two).route("/three", get(|| async { "three" }));
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/one", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/two", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/three", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn layer() {
|
async fn layer() {
|
||||||
let one = route("/foo", get(|| async {}));
|
let one = route("/foo", get(|| async {}));
|
||||||
|
|
Loading…
Reference in a new issue