mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-04 02:01:23 +01:00
Fix fallback panic on CONNECT requests (#1963)
This commit is contained in:
parent
d1765d9a00
commit
c7d4af9bd8
5 changed files with 123 additions and 38 deletions
|
@ -62,6 +62,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
[#1868]: https://github.com/tokio-rs/axum/pull/1868
|
[#1868]: https://github.com/tokio-rs/axum/pull/1868
|
||||||
[#1956]: https://github.com/tokio-rs/axum/pull/1956
|
[#1956]: https://github.com/tokio-rs/axum/pull/1956
|
||||||
|
|
||||||
|
# 0.6.17 (25. April, 2023)
|
||||||
|
|
||||||
|
- **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958])
|
||||||
|
|
||||||
|
[#1958]: https://github.com/tokio-rs/axum/pull/1958
|
||||||
|
|
||||||
# 0.6.16 (18. April, 2023)
|
# 0.6.16 (18. April, 2023)
|
||||||
|
|
||||||
- **fixed:** Don't allow extracting `MatchedPath` in fallbacks ([#1934])
|
- **fixed:** Don't allow extracting `MatchedPath` in fallbacks ([#1934])
|
||||||
|
|
|
@ -1072,15 +1072,7 @@ where
|
||||||
call!(req, method, DELETE, delete);
|
call!(req, method, DELETE, delete);
|
||||||
call!(req, method, TRACE, trace);
|
call!(req, method, TRACE, trace);
|
||||||
|
|
||||||
let future = match fallback {
|
let future = fallback.call_with_state(req, state);
|
||||||
Fallback::Default(route) | Fallback::Service(route) => {
|
|
||||||
RouteFuture::from_future(route.oneshot_inner(req))
|
|
||||||
}
|
|
||||||
Fallback::BoxedHandler(handler) => {
|
|
||||||
let mut route = handler.clone().into_route(state);
|
|
||||||
RouteFuture::from_future(route.oneshot_inner(req))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match allow_header {
|
match allow_header {
|
||||||
AllowHeader::None => future.allow_header(Bytes::new()),
|
AllowHeader::None => future.allow_header(Bytes::new()),
|
||||||
|
|
|
@ -62,6 +62,7 @@ pub struct Router<S = ()> {
|
||||||
path_router: PathRouter<S, false>,
|
path_router: PathRouter<S, false>,
|
||||||
fallback_router: PathRouter<S, true>,
|
fallback_router: PathRouter<S, true>,
|
||||||
default_fallback: bool,
|
default_fallback: bool,
|
||||||
|
catch_all_fallback: Fallback<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> Clone for Router<S> {
|
impl<S> Clone for Router<S> {
|
||||||
|
@ -70,6 +71,7 @@ impl<S> Clone for Router<S> {
|
||||||
path_router: self.path_router.clone(),
|
path_router: self.path_router.clone(),
|
||||||
fallback_router: self.fallback_router.clone(),
|
fallback_router: self.fallback_router.clone(),
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: self.default_fallback,
|
||||||
|
catch_all_fallback: self.catch_all_fallback.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -89,6 +91,7 @@ impl<S> fmt::Debug for Router<S> {
|
||||||
.field("path_router", &self.path_router)
|
.field("path_router", &self.path_router)
|
||||||
.field("fallback_router", &self.fallback_router)
|
.field("fallback_router", &self.fallback_router)
|
||||||
.field("default_fallback", &self.default_fallback)
|
.field("default_fallback", &self.default_fallback)
|
||||||
|
.field("catch_all_fallback", &self.catch_all_fallback)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,14 +109,12 @@ where
|
||||||
/// Unless you add additional routes this will respond with `404 Not Found` to
|
/// Unless you add additional routes this will respond with `404 Not Found` to
|
||||||
/// all requests.
|
/// all requests.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let mut this = Self {
|
Self {
|
||||||
path_router: Default::default(),
|
path_router: Default::default(),
|
||||||
fallback_router: Default::default(),
|
fallback_router: PathRouter::new_fallback(),
|
||||||
default_fallback: true,
|
default_fallback: true,
|
||||||
};
|
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
|
||||||
this = this.fallback_service(NotFound);
|
}
|
||||||
this.default_fallback = true;
|
|
||||||
this
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[doc = include_str!("../docs/routing/route.md")]
|
#[doc = include_str!("../docs/routing/route.md")]
|
||||||
|
@ -151,6 +152,10 @@ where
|
||||||
path_router,
|
path_router,
|
||||||
fallback_router,
|
fallback_router,
|
||||||
default_fallback,
|
default_fallback,
|
||||||
|
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
|
||||||
|
// requests with an empty path. If we were to inherit the catch-all fallback
|
||||||
|
// it would end up matching `/{path}/*` which doesn't match empty paths.
|
||||||
|
catch_all_fallback: _,
|
||||||
} = router;
|
} = router;
|
||||||
|
|
||||||
panic_on_err!(self.path_router.nest(path, path_router));
|
panic_on_err!(self.path_router.nest(path, path_router));
|
||||||
|
@ -184,6 +189,7 @@ where
|
||||||
path_router,
|
path_router,
|
||||||
fallback_router: other_fallback,
|
fallback_router: other_fallback,
|
||||||
default_fallback,
|
default_fallback,
|
||||||
|
catch_all_fallback,
|
||||||
} = other.into();
|
} = other.into();
|
||||||
|
|
||||||
panic_on_err!(self.path_router.merge(path_router));
|
panic_on_err!(self.path_router.merge(path_router));
|
||||||
|
@ -208,6 +214,11 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
self.catch_all_fallback = self
|
||||||
|
.catch_all_fallback
|
||||||
|
.merge(catch_all_fallback)
|
||||||
|
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
|
||||||
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,8 +233,9 @@ where
|
||||||
{
|
{
|
||||||
Router {
|
Router {
|
||||||
path_router: self.path_router.layer(layer.clone()),
|
path_router: self.path_router.layer(layer.clone()),
|
||||||
fallback_router: self.fallback_router.layer(layer),
|
fallback_router: self.fallback_router.layer(layer.clone()),
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: self.default_fallback,
|
||||||
|
catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,36 +253,38 @@ where
|
||||||
path_router: self.path_router.route_layer(layer),
|
path_router: self.path_router.route_layer(layer),
|
||||||
fallback_router: self.fallback_router,
|
fallback_router: self.fallback_router,
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: self.default_fallback,
|
||||||
|
catch_all_fallback: self.catch_all_fallback,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
#[doc = include_str!("../docs/routing/fallback.md")]
|
#[doc = include_str!("../docs/routing/fallback.md")]
|
||||||
pub fn fallback<H, T>(self, handler: H) -> Self
|
pub fn fallback<H, T>(mut self, handler: H) -> Self
|
||||||
where
|
where
|
||||||
H: Handler<T, S>,
|
H: Handler<T, S>,
|
||||||
T: 'static,
|
T: 'static,
|
||||||
{
|
{
|
||||||
let endpoint = Endpoint::MethodRouter(any(handler));
|
self.catch_all_fallback =
|
||||||
self.fallback_endpoint(endpoint)
|
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
|
||||||
|
self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a fallback [`Service`] to the router.
|
/// Add a fallback [`Service`] to the router.
|
||||||
///
|
///
|
||||||
/// See [`Router::fallback`] for more details.
|
/// See [`Router::fallback`] for more details.
|
||||||
pub fn fallback_service<T>(self, service: T) -> Self
|
pub fn fallback_service<T>(mut self, service: T) -> Self
|
||||||
where
|
where
|
||||||
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
|
||||||
T::Response: IntoResponse,
|
T::Response: IntoResponse,
|
||||||
T::Future: Send + 'static,
|
T::Future: Send + 'static,
|
||||||
{
|
{
|
||||||
self.fallback_endpoint(Endpoint::Route(Route::new(service)))
|
let route = Route::new(service);
|
||||||
|
self.catch_all_fallback = Fallback::Service(route.clone());
|
||||||
|
self.fallback_endpoint(Endpoint::Route(route))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fallback_endpoint(mut self, endpoint: Endpoint<S>) -> Self {
|
fn fallback_endpoint(mut self, endpoint: Endpoint<S>) -> Self {
|
||||||
self.fallback_router.replace_endpoint("/", endpoint.clone());
|
self.fallback_router.set_fallback(endpoint);
|
||||||
self.fallback_router
|
|
||||||
.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
|
|
||||||
self.default_fallback = false;
|
self.default_fallback = false;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -279,21 +293,24 @@ where
|
||||||
pub fn with_state<S2>(self, state: S) -> Router<S2> {
|
pub fn with_state<S2>(self, state: S) -> Router<S2> {
|
||||||
Router {
|
Router {
|
||||||
path_router: self.path_router.with_state(state.clone()),
|
path_router: self.path_router.with_state(state.clone()),
|
||||||
fallback_router: self.fallback_router.with_state(state),
|
fallback_router: self.fallback_router.with_state(state.clone()),
|
||||||
default_fallback: self.default_fallback,
|
default_fallback: self.default_fallback,
|
||||||
|
catch_all_fallback: self.catch_all_fallback.with_state(state),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<Infallible> {
|
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<Infallible> {
|
||||||
match self.path_router.call_with_state(req, state) {
|
let (req, state) = match self.path_router.call_with_state(req, state) {
|
||||||
Ok(future) => future,
|
Ok(future) => return future,
|
||||||
Err((req, state)) => match self.fallback_router.call_with_state(req, state) {
|
Err((req, state)) => (req, state),
|
||||||
Ok(future) => future,
|
};
|
||||||
Err((_req, _state)) => {
|
|
||||||
unreachable!("the default fallback added in `Router::new` matches everything")
|
let (req, state) = match self.fallback_router.call_with_state(req, state) {
|
||||||
}
|
Ok(future) => return future,
|
||||||
},
|
Err((req, state)) => (req, state),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
self.catch_all_fallback.call_with_state(req, state)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type
|
/// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type
|
||||||
|
@ -564,6 +581,18 @@ where
|
||||||
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
|
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
|
||||||
|
match self {
|
||||||
|
Fallback::Default(route) | Fallback::Service(route) => {
|
||||||
|
RouteFuture::from_future(route.oneshot_inner(req))
|
||||||
|
}
|
||||||
|
Fallback::BoxedHandler(handler) => {
|
||||||
|
let mut route = handler.clone().into_route(state);
|
||||||
|
RouteFuture::from_future(route.oneshot_inner(req))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, E> Clone for Fallback<S, E> {
|
impl<S, E> Clone for Fallback<S, E> {
|
||||||
|
|
|
@ -6,8 +6,8 @@ use tower_layer::Layer;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
|
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
|
||||||
RouteId, NEST_TAIL_PARAM,
|
MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
|
pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
|
||||||
|
@ -16,6 +16,22 @@ pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
|
||||||
prev_route_id: RouteId,
|
prev_route_id: RouteId,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S> PathRouter<S, true>
|
||||||
|
where
|
||||||
|
S: Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
pub(super) fn new_fallback() -> Self {
|
||||||
|
let mut this = Self::default();
|
||||||
|
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
|
||||||
|
this
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
|
||||||
|
self.replace_endpoint("/", endpoint.clone());
|
||||||
|
self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
|
impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
|
||||||
where
|
where
|
||||||
S: Clone + Send + Sync + 'static,
|
S: Clone + Send + Sync + 'static,
|
||||||
|
|
|
@ -16,7 +16,11 @@ use crate::{
|
||||||
};
|
};
|
||||||
use axum_core::extract::Request;
|
use axum_core::extract::Request;
|
||||||
use futures_util::stream::StreamExt;
|
use futures_util::stream::StreamExt;
|
||||||
use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, StatusCode, Uri};
|
use http::{
|
||||||
|
header::CONTENT_LENGTH,
|
||||||
|
header::{ALLOW, HOST},
|
||||||
|
HeaderMap, Method, StatusCode, Uri,
|
||||||
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -26,7 +30,9 @@ use std::{
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder};
|
use tower::{
|
||||||
|
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
|
||||||
|
};
|
||||||
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
|
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
|
@ -979,3 +985,39 @@ async fn logging_rejections() {
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/tokio-rs/axum/issues/1955
|
||||||
|
#[crate::test]
|
||||||
|
async fn connect_going_to_custom_fallback() {
|
||||||
|
let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") });
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.uri("example.com:443")
|
||||||
|
.method(Method::CONNECT)
|
||||||
|
.header(HOST, "example.com:443")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||||
|
let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap();
|
||||||
|
assert_eq!(text, "custom fallback");
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/tokio-rs/axum/issues/1955
|
||||||
|
#[crate::test]
|
||||||
|
async fn connect_going_to_default_fallback() {
|
||||||
|
let app = Router::new();
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.uri("example.com:443")
|
||||||
|
.method(Method::CONNECT)
|
||||||
|
.header(HOST, "example.com:443")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let res = app.oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||||
|
let body = hyper::body::to_bytes(res).await.unwrap();
|
||||||
|
assert!(body.is_empty());
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue