Implement nesting as regular wildcard routes (#426)

When fixing bugs with `MatchedPath` (introduced to fix https://github.com/tokio-rs/axum/issues/386) I realized nesting could basically be implemented using regular routes, if we can detect that the service passed to `nest` is in fact a `Router`. Then we can transfer over all its routes and add the prefix.

This makes nesting much simpler in general and should also be slightly faster since we're no longer nesting routers.
This commit is contained in:
David Pedersen 2021-10-26 22:31:22 +02:00 committed by GitHub
parent fc9bfb8a50
commit 92f96a201c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 223 additions and 153 deletions

View file

@ -6,7 +6,7 @@ use crate::{
body::{box_body, Body, BoxBody}, body::{box_body, Body, BoxBody},
extract::{ extract::{
connect_info::{Connected, IntoMakeServiceWithConnectInfo}, connect_info::{Connected, IntoMakeServiceWithConnectInfo},
OriginalUri, MatchedPath, OriginalUri,
}, },
util::{ByteStr, PercentDecodedByteStr}, util::{ByteStr, PercentDecodedByteStr},
BoxError, BoxError,
@ -33,9 +33,9 @@ pub mod service_method_router;
mod into_make_service; mod into_make_service;
mod method_filter; mod method_filter;
mod method_not_allowed; mod method_not_allowed;
mod nested;
mod not_found; mod not_found;
mod route; mod route;
mod strip_prefix;
pub(crate) use self::method_not_allowed::MethodNotAllowed; pub(crate) use self::method_not_allowed::MethodNotAllowed;
pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
@ -92,7 +92,8 @@ impl<B> fmt::Debug for Router<B> {
} }
} }
const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture"; pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest";
const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest";
impl<B> Router<B> impl<B> Router<B>
where where
@ -342,20 +343,43 @@ where
panic!("Invalid route: nested routes cannot contain wildcards (*)"); panic!("Invalid route: nested routes cannot contain wildcards (*)");
} }
let id = RouteId::next(); let prefix = path;
let path = if path == "/" { match try_downcast::<Router<B>, _>(svc) {
format!("/*{}", NEST_TAIL_PARAM) // if the user is nesting a `Router` we can implement nesting
} else { // by simplying copying all the routes and adding the prefix in
format!("{}/*{}", path, NEST_TAIL_PARAM) // front
}; Ok(router) => {
let Router {
mut routes,
node,
fallback: _,
} = router;
if let Err(err) = self.node.insert(path, id) { for (id, nested_path) in node.paths {
panic!("Invalid route: {}", err); let route = routes.remove(&id).unwrap();
let full_path = if &*nested_path == "/" {
path.to_string()
} else {
format!("{}{}", path, nested_path)
};
self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix));
}
debug_assert!(routes.is_empty());
}
// otherwise we add a wildcard route to the service
Err(svc) => {
let path = if path == "/" {
format!("/*{}", NEST_TAIL_PARAM)
} else {
format!("{}/*{}", path, NEST_TAIL_PARAM)
};
self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix));
}
} }
self.routes.insert(id, Route::new(nested::Nested { svc }));
self self
} }
@ -700,31 +724,8 @@ where
/// ## When used with `Router::nest` /// ## When used with `Router::nest`
/// ///
/// If a router with a fallback is nested inside another router the fallback /// If a router with a fallback is nested inside another router the fallback
/// will only apply to requests that matches the prefix: /// of the nested router will be discarded and not used. This is such that
/// /// the outer router's fallback takes precedence.
/// ```rust
/// use axum::{
/// Router,
/// routing::get,
/// handler::Handler,
/// response::IntoResponse,
/// http::{StatusCode, Uri},
/// };
///
/// let api = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .fallback(api_fallback.into_service());
///
/// let app = Router::new().nest("/api", api);
///
/// async fn api_fallback() -> impl IntoResponse { /* ... */ }
///
/// // `api_fallback` will be called for `/api/some-unknown-path` but not for
/// // `/some-unknown-path` as the path doesn't start with `/api`
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn fallback<T>(mut self, svc: T) -> Self pub fn fallback<T>(mut self, svc: T) -> Self
where where
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
@ -743,8 +744,22 @@ where
req.extensions_mut().insert(id); req.extensions_mut().insert(id);
if let Some(matched_path) = self.node.paths.get(&id) { if let Some(matched_path) = self.node.paths.get(&id) {
req.extensions_mut() let matched_path = if let Some(previous) = req.extensions_mut().get::<MatchedPath>() {
.insert(crate::extract::MatchedPath(matched_path.clone())); // a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) =
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
{
previous
} else {
previous.as_str()
};
let matched_path = format!("{}{}", previous, matched_path);
matched_path.into()
} else {
Arc::clone(matched_path)
};
req.extensions_mut().insert(MatchedPath(matched_path));
} }
let params = match_ let params = match_
@ -754,11 +769,6 @@ where
.map(|(key, value)| (key.to_string(), value.to_string())) .map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
req.extensions_mut()
.insert(nested::NestMatchTail(tail.to_string()));
}
insert_url_params(&mut req, params); insert_url_params(&mut req, params);
let route = self let route = self
@ -970,14 +980,22 @@ impl<B> Fallback<B> {
} }
} }
#[cfg(test)] fn try_downcast<T, K>(k: K) -> Result<T, K>
mod tests { where
use super::*; T: 'static,
K: Send + 'static,
{
use std::any::Any;
#[test] let k = Box::new(k) as Box<dyn Any + Send + 'static>;
fn traits() { match k.downcast() {
use crate::tests::*; Ok(t) => Ok(*t),
Err(other) => Err(*other.downcast().unwrap()),
assert_send::<Router<()>>();
} }
} }
#[test]
fn traits() {
use crate::tests::*;
assert_send::<Router<()>>();
}

View file

@ -1,69 +0,0 @@
use crate::body::BoxBody;
use http::{Request, Response, Uri};
use std::{
convert::Infallible,
task::{Context, Poll},
};
use tower::util::Oneshot;
use tower::ServiceExt;
use tower_service::Service;
/// A [`Service`] that has been nested inside a router at some path.
///
/// Created with [`Router::nest`].
#[derive(Debug, Clone)]
pub(super) struct Nested<S> {
pub(super) svc: S,
}
impl<B, S> Service<Request<B>> for Nested<S>
where
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = Oneshot<S, Request<B>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
// strip the prefix from the URI just before calling the inner service
// such that any surrounding middleware still see the full path
if let Some(tail) = req.extensions_mut().remove::<NestMatchTail>() {
UriStack::push(&mut req);
let new_uri = super::with_path(req.uri(), &tail.0);
*req.uri_mut() = new_uri;
}
self.svc.clone().oneshot(req)
}
}
pub(crate) struct UriStack(Vec<Uri>);
impl UriStack {
fn push<B>(req: &mut Request<B>) {
let uri = req.uri().clone();
if let Some(stack) = req.extensions_mut().get_mut::<Self>() {
stack.0.push(uri);
} else {
req.extensions_mut().insert(Self(vec![uri]));
}
}
}
#[derive(Clone)]
pub(super) struct NestMatchTail(pub(super) String);
#[test]
fn traits() {
use crate::tests::*;
assert_send::<Nested<()>>();
assert_sync::<Nested<()>>();
}

View file

@ -0,0 +1,77 @@
use http::{Request, Uri};
use std::{
borrow::Cow,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
#[derive(Clone)]
pub(super) struct StripPrefix<S> {
inner: S,
prefix: Arc<str>,
}
impl<S> StripPrefix<S> {
pub(super) fn new(inner: S, prefix: &str) -> Self {
Self {
inner,
prefix: prefix.into(),
}
}
}
impl<S, B> Service<Request<B>> for StripPrefix<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let new_uri = strip_prefix(req.uri(), &self.prefix);
*req.uri_mut() = new_uri;
self.inner.call(req)
}
}
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
} else {
path_and_query.path()
};
let new_path = if new_path.starts_with('/') {
Cow::Borrowed(new_path)
} else {
Cow::Owned(format!("/{}", new_path))
};
if let Some(query) = path_and_query.query() {
Some(
format!("{}?{}", new_path, query)
.parse::<http::uri::PathAndQuery>()
.unwrap(),
)
} else {
Some(new_path.parse().unwrap())
}
} else {
None
};
let mut parts = http::uri::Parts::default();
parts.scheme = uri.scheme().cloned();
parts.authority = uri.authority().cloned();
parts.path_and_query = path_and_query;
Uri::from_parts(parts).unwrap()
}

View file

@ -31,29 +31,6 @@ async fn nest() {
assert_eq!(res.text().await, "fallback"); assert_eq!(res.text().await, "fallback");
} }
#[tokio::test]
async fn nesting_with_fallback() {
let app = Router::new().nest(
"/foo",
Router::new()
.route("/bar", get(|| async {}))
.fallback((|| async { "fallback" }).into_service()),
);
let client = TestClient::new(app);
assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK);
// this shouldn't exist because the fallback is inside the nested router
let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
// this should work since we get into the nested router
let res = client.get("/foo/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "fallback");
}
#[tokio::test] #[tokio::test]
async fn or() { async fn or() {
let one = Router::new().route("/one", get(|| async {})); let one = Router::new().route("/one", get(|| async {}));

View file

@ -1,7 +1,7 @@
#![allow(clippy::blacklisted_name)] #![allow(clippy::blacklisted_name)]
use crate::error_handling::HandleErrorLayer; use crate::error_handling::HandleErrorLayer;
use crate::extract::MatchedPath; use crate::extract::{Extension, MatchedPath};
use crate::BoxError; use crate::BoxError;
use crate::{ use crate::{
extract::{self, Path}, extract::{self, Path},
@ -29,7 +29,6 @@ use std::{
}; };
use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder}; use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder};
use tower_http::auth::RequireAuthorizationLayer; use tower_http::auth::RequireAuthorizationLayer;
use tower_http::trace::TraceLayer;
use tower_service::Service; use tower_service::Service;
pub(crate) use helpers::*; pub(crate) use helpers::*;
@ -592,24 +591,92 @@ async fn with_and_without_trailing_slash() {
assert_eq!(res.text().await, "without tsr"); assert_eq!(res.text().await, "without tsr");
} }
#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);
impl<B, S> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.unwrap()
.as_str()
.to_string();
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
self.0.call(req)
}
}
#[derive(Clone)]
struct MatchedPathFromMiddleware(String);
#[tokio::test] #[tokio::test]
async fn access_matched_path() { async fn access_matched_path() {
let api = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move { path.as_str().to_string() }),
);
async fn handler(
path: MatchedPath,
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
MatchedPathFromMiddleware,
>,
) -> String {
format!(
"extractor = {}, middleware = {}",
path.as_str(),
path_from_middleware
)
}
let app = Router::new() let app = Router::new()
.route( .route(
"/:key", "/:key",
get(|path: MatchedPath| async move { path.as_str().to_string() }), get(|path: MatchedPath| async move { path.as_str().to_string() }),
) )
.layer( .nest("/api", api)
TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { .nest(
let path = req.extensions().get::<MatchedPath>().unwrap().as_str(); "/public",
tracing::info_span!("http-request", %path) Router::new().route("/assets/*path", get(handler)),
}), )
); .nest("/foo", handler.into_service())
.layer(tower::layer::layer_fn(SetMatchedPathExtension));
let client = TestClient::new(app); let client = TestClient::new(app);
let res = client.get("/foo").send().await; let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:key"); assert_eq!(res.text().await, "/:key");
let res = client.get("/api/users/123").send().await;
assert_eq!(res.text().await, "/api/users/:id");
let res = client.get("/public/assets/css/style.css").send().await;
assert_eq!(
res.text().await,
"extractor = /public/assets/*path, middleware = /public/assets/*path"
);
let res = client.get("/foo/bar/baz").send().await;
assert_eq!(
res.text().await,
format!(
"extractor = /foo/*{}, middleware = /foo/*{}",
crate::routing::NEST_TAIL_PARAM,
crate::routing::NEST_TAIL_PARAM,
),
);
} }
#[tokio::test] #[tokio::test]