From c7d4af9bd8031f85244319dbd12f1d0698b12432 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 25 Apr 2023 16:01:39 +0200 Subject: [PATCH] Fix fallback panic on CONNECT requests (#1963) --- axum/CHANGELOG.md | 6 +++ axum/src/routing/method_routing.rs | 10 +--- axum/src/routing/mod.rs | 79 ++++++++++++++++++++---------- axum/src/routing/path_router.rs | 20 +++++++- axum/src/routing/tests/mod.rs | 46 ++++++++++++++++- 5 files changed, 123 insertions(+), 38 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index a6d827e3..b7fd2348 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -62,6 +62,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1868]: https://github.com/tokio-rs/axum/pull/1868 [#1956]: https://github.com/tokio-rs/axum/pull/1956 +# 0.6.17 (25. April, 2023) + +- **fixed:** Fix fallbacks causing a panic on `CONNECT` requests ([#1958]) + +[#1958]: https://github.com/tokio-rs/axum/pull/1958 + # 0.6.16 (18. April, 2023) - **fixed:** Don't allow extracting `MatchedPath` in fallbacks ([#1934]) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index dc8f9ec0..e1545a00 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1072,15 +1072,7 @@ where call!(req, method, DELETE, delete); call!(req, method, TRACE, trace); - let future = match fallback { - Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) - } - Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) - } - }; + let future = fallback.call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 6a5f6887..eaf489b5 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -62,6 +62,7 @@ pub struct Router { path_router: PathRouter, fallback_router: PathRouter, default_fallback: bool, + catch_all_fallback: Fallback, } impl Clone for Router { @@ -70,6 +71,7 @@ impl Clone for Router { path_router: self.path_router.clone(), fallback_router: self.fallback_router.clone(), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.clone(), } } } @@ -89,6 +91,7 @@ impl fmt::Debug for Router { .field("path_router", &self.path_router) .field("fallback_router", &self.fallback_router) .field("default_fallback", &self.default_fallback) + .field("catch_all_fallback", &self.catch_all_fallback) .finish() } } @@ -106,14 +109,12 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { - let mut this = Self { + Self { path_router: Default::default(), - fallback_router: Default::default(), + fallback_router: PathRouter::new_fallback(), default_fallback: true, - }; - this = this.fallback_service(NotFound); - this.default_fallback = true; - this + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + } } #[doc = include_str!("../docs/routing/route.md")] @@ -151,6 +152,10 @@ where path_router, fallback_router, default_fallback, + // we don't need to inherit the catch-all fallback. It is only used for CONNECT + // requests with an empty path. If we were to inherit the catch-all fallback + // it would end up matching `/{path}/*` which doesn't match empty paths. + catch_all_fallback: _, } = router; panic_on_err!(self.path_router.nest(path, path_router)); @@ -184,6 +189,7 @@ where path_router, fallback_router: other_fallback, default_fallback, + catch_all_fallback, } = other.into(); panic_on_err!(self.path_router.merge(path_router)); @@ -208,6 +214,11 @@ where } }; + self.catch_all_fallback = self + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + self } @@ -222,8 +233,9 @@ where { Router { path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer), + fallback_router: self.fallback_router.layer(layer.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), } } @@ -241,36 +253,38 @@ where path_router: self.path_router.route_layer(layer), fallback_router: self.fallback_router, default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback, } } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(self, handler: H) -> Self + pub fn fallback(mut self, handler: H) -> Self where H: Handler, T: 'static, { - let endpoint = Endpoint::MethodRouter(any(handler)); - self.fallback_endpoint(endpoint) + self.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(self, service: T) -> Self + pub fn fallback_service(mut self, service: T) -> Self where T: Service + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - self.fallback_endpoint(Endpoint::Route(Route::new(service))) + let route = Route::new(service); + self.catch_all_fallback = Fallback::Service(route.clone()); + self.fallback_endpoint(Endpoint::Route(route)) } fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { - self.fallback_router.replace_endpoint("/", endpoint.clone()); - self.fallback_router - .replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + self.fallback_router.set_fallback(endpoint); self.default_fallback = false; self } @@ -279,21 +293,24 @@ where pub fn with_state(self, state: S) -> Router { Router { path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state), + fallback_router: self.fallback_router.with_state(state.clone()), default_fallback: self.default_fallback, + catch_all_fallback: self.catch_all_fallback.with_state(state), } } pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { - match self.path_router.call_with_state(req, state) { - Ok(future) => future, - Err((req, state)) => match self.fallback_router.call_with_state(req, state) { - Ok(future) => future, - Err((_req, _state)) => { - unreachable!("the default fallback added in `Router::new` matches everything") - } - }, - } + let (req, state) = match self.path_router.call_with_state(req, state) { + Ok(future) => return future, + Err((req, state)) => (req, state), + }; + + let (req, state) = match self.fallback_router.call_with_state(req, state) { + Ok(future) => return future, + Err((req, state)) => (req, state), + }; + + self.catch_all_fallback.call_with_state(req, state) } /// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type @@ -564,6 +581,18 @@ where Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)), } } + + fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + match self { + Fallback::Default(route) | Fallback::Service(route) => { + RouteFuture::from_future(route.oneshot_inner(req)) + } + Fallback::BoxedHandler(handler) => { + let mut route = handler.clone().into_route(state); + RouteFuture::from_future(route.oneshot_inner(req)) + } + } + } } impl Clone for Fallback { diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 6baa7943..432066ac 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -6,8 +6,8 @@ use tower_layer::Layer; use tower_service::Service; use super::{ - future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, - RouteId, NEST_TAIL_PARAM, + future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, + MethodRouter, Route, RouteId, FALLBACK_PARAM, NEST_TAIL_PARAM, }; pub(super) struct PathRouter { @@ -16,6 +16,22 @@ pub(super) struct PathRouter { prev_route_id: RouteId, } +impl PathRouter +where + S: Clone + Send + Sync + 'static, +{ + pub(super) fn new_fallback() -> Self { + let mut this = Self::default(); + this.set_fallback(Endpoint::Route(Route::new(NotFound))); + this + } + + pub(super) fn set_fallback(&mut self, endpoint: Endpoint) { + self.replace_endpoint("/", endpoint.clone()); + self.replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + } +} + impl PathRouter where S: Clone + Send + Sync + 'static, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 40585eed..02296feb 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -16,7 +16,11 @@ use crate::{ }; use axum_core::extract::Request; use futures_util::stream::StreamExt; -use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, StatusCode, Uri}; +use http::{ + header::CONTENT_LENGTH, + header::{ALLOW, HOST}, + HeaderMap, Method, StatusCode, Uri, +}; use serde::Deserialize; use serde_json::json; use std::{ @@ -26,7 +30,9 @@ use std::{ task::{Context, Poll}, time::Duration, }; -use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder}; +use tower::{ + service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt, +}; use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer}; use tower_service::Service; @@ -979,3 +985,39 @@ async fn logging_rejections() { ]) ) } + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_custom_fallback() { + let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") }); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap(); + assert_eq!(text, "custom fallback"); +} + +// https://github.com/tokio-rs/axum/issues/1955 +#[crate::test] +async fn connect_going_to_default_fallback() { + let app = Router::new(); + + let req = Request::builder() + .uri("example.com:443") + .method(Method::CONNECT) + .header(HOST, "example.com:443") + .body(Body::empty()) + .unwrap(); + + let res = app.oneshot(req).await.unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert!(body.is_empty()); +}