Implement nesting as regular wildcard routes (#426)

When fixing bugs with `MatchedPath` (introduced to fix https://github.com/tokio-rs/axum/issues/386) I realized nesting could basically be implemented using regular routes, if we can detect that the service passed to `nest` is in fact a `Router`. Then we can transfer over all its routes and add the prefix.

This makes nesting much simpler in general and should also be slightly faster since we're no longer nesting routers.
This commit is contained in:
David Pedersen 2021-10-26 22:31:22 +02:00 committed by GitHub
parent fc9bfb8a50
commit 92f96a201c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 223 additions and 153 deletions

View file

@ -6,7 +6,7 @@ use crate::{
body::{box_body, Body, BoxBody},
extract::{
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
OriginalUri,
MatchedPath, OriginalUri,
},
util::{ByteStr, PercentDecodedByteStr},
BoxError,
@ -33,9 +33,9 @@ pub mod service_method_router;
mod into_make_service;
mod method_filter;
mod method_not_allowed;
mod nested;
mod not_found;
mod route;
mod strip_prefix;
pub(crate) use self::method_not_allowed::MethodNotAllowed;
pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
@ -92,7 +92,8 @@ impl<B> fmt::Debug for Router<B> {
}
}
const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture";
pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest";
const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest";
impl<B> Router<B>
where
@ -342,20 +343,43 @@ where
panic!("Invalid route: nested routes cannot contain wildcards (*)");
}
let id = RouteId::next();
let prefix = path;
let path = if path == "/" {
format!("/*{}", NEST_TAIL_PARAM)
} else {
format!("{}/*{}", path, NEST_TAIL_PARAM)
};
match try_downcast::<Router<B>, _>(svc) {
// if the user is nesting a `Router` we can implement nesting
// by simplying copying all the routes and adding the prefix in
// front
Ok(router) => {
let Router {
mut routes,
node,
fallback: _,
} = router;
if let Err(err) = self.node.insert(path, id) {
panic!("Invalid route: {}", err);
for (id, nested_path) in node.paths {
let route = routes.remove(&id).unwrap();
let full_path = if &*nested_path == "/" {
path.to_string()
} else {
format!("{}{}", path, nested_path)
};
self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix));
}
debug_assert!(routes.is_empty());
}
// otherwise we add a wildcard route to the service
Err(svc) => {
let path = if path == "/" {
format!("/*{}", NEST_TAIL_PARAM)
} else {
format!("{}/*{}", path, NEST_TAIL_PARAM)
};
self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix));
}
}
self.routes.insert(id, Route::new(nested::Nested { svc }));
self
}
@ -700,31 +724,8 @@ where
/// ## When used with `Router::nest`
///
/// If a router with a fallback is nested inside another router the fallback
/// will only apply to requests that matches the prefix:
///
/// ```rust
/// use axum::{
/// Router,
/// routing::get,
/// handler::Handler,
/// response::IntoResponse,
/// http::{StatusCode, Uri},
/// };
///
/// let api = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .fallback(api_fallback.into_service());
///
/// let app = Router::new().nest("/api", api);
///
/// async fn api_fallback() -> impl IntoResponse { /* ... */ }
///
/// // `api_fallback` will be called for `/api/some-unknown-path` but not for
/// // `/some-unknown-path` as the path doesn't start with `/api`
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
/// of the nested router will be discarded and not used. This is such that
/// the outer router's fallback takes precedence.
pub fn fallback<T>(mut self, svc: T) -> Self
where
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
@ -743,8 +744,22 @@ where
req.extensions_mut().insert(id);
if let Some(matched_path) = self.node.paths.get(&id) {
req.extensions_mut()
.insert(crate::extract::MatchedPath(matched_path.clone()));
let matched_path = if let Some(previous) = req.extensions_mut().get::<MatchedPath>() {
// a previous `MatchedPath` might exist if we're inside a nested Router
let previous = if let Some(previous) =
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
{
previous
} else {
previous.as_str()
};
let matched_path = format!("{}{}", previous, matched_path);
matched_path.into()
} else {
Arc::clone(matched_path)
};
req.extensions_mut().insert(MatchedPath(matched_path));
}
let params = match_
@ -754,11 +769,6 @@ where
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>();
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
req.extensions_mut()
.insert(nested::NestMatchTail(tail.to_string()));
}
insert_url_params(&mut req, params);
let route = self
@ -970,14 +980,22 @@ impl<B> Fallback<B> {
}
}
#[cfg(test)]
mod tests {
use super::*;
fn try_downcast<T, K>(k: K) -> Result<T, K>
where
T: 'static,
K: Send + 'static,
{
use std::any::Any;
#[test]
fn traits() {
use crate::tests::*;
assert_send::<Router<()>>();
let k = Box::new(k) as Box<dyn Any + Send + 'static>;
match k.downcast() {
Ok(t) => Ok(*t),
Err(other) => Err(*other.downcast().unwrap()),
}
}
#[test]
fn traits() {
use crate::tests::*;
assert_send::<Router<()>>();
}

View file

@ -1,69 +0,0 @@
use crate::body::BoxBody;
use http::{Request, Response, Uri};
use std::{
convert::Infallible,
task::{Context, Poll},
};
use tower::util::Oneshot;
use tower::ServiceExt;
use tower_service::Service;
/// A [`Service`] that has been nested inside a router at some path.
///
/// Created with [`Router::nest`].
#[derive(Debug, Clone)]
pub(super) struct Nested<S> {
pub(super) svc: S,
}
impl<B, S> Service<Request<B>> for Nested<S>
where
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = Oneshot<S, Request<B>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
// strip the prefix from the URI just before calling the inner service
// such that any surrounding middleware still see the full path
if let Some(tail) = req.extensions_mut().remove::<NestMatchTail>() {
UriStack::push(&mut req);
let new_uri = super::with_path(req.uri(), &tail.0);
*req.uri_mut() = new_uri;
}
self.svc.clone().oneshot(req)
}
}
pub(crate) struct UriStack(Vec<Uri>);
impl UriStack {
fn push<B>(req: &mut Request<B>) {
let uri = req.uri().clone();
if let Some(stack) = req.extensions_mut().get_mut::<Self>() {
stack.0.push(uri);
} else {
req.extensions_mut().insert(Self(vec![uri]));
}
}
}
#[derive(Clone)]
pub(super) struct NestMatchTail(pub(super) String);
#[test]
fn traits() {
use crate::tests::*;
assert_send::<Nested<()>>();
assert_sync::<Nested<()>>();
}

View file

@ -0,0 +1,77 @@
use http::{Request, Uri};
use std::{
borrow::Cow,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
#[derive(Clone)]
pub(super) struct StripPrefix<S> {
inner: S,
prefix: Arc<str>,
}
impl<S> StripPrefix<S> {
pub(super) fn new(inner: S, prefix: &str) -> Self {
Self {
inner,
prefix: prefix.into(),
}
}
}
impl<S, B> Service<Request<B>> for StripPrefix<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let new_uri = strip_prefix(req.uri(), &self.prefix);
*req.uri_mut() = new_uri;
self.inner.call(req)
}
}
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
} else {
path_and_query.path()
};
let new_path = if new_path.starts_with('/') {
Cow::Borrowed(new_path)
} else {
Cow::Owned(format!("/{}", new_path))
};
if let Some(query) = path_and_query.query() {
Some(
format!("{}?{}", new_path, query)
.parse::<http::uri::PathAndQuery>()
.unwrap(),
)
} else {
Some(new_path.parse().unwrap())
}
} else {
None
};
let mut parts = http::uri::Parts::default();
parts.scheme = uri.scheme().cloned();
parts.authority = uri.authority().cloned();
parts.path_and_query = path_and_query;
Uri::from_parts(parts).unwrap()
}

View file

@ -31,29 +31,6 @@ async fn nest() {
assert_eq!(res.text().await, "fallback");
}
#[tokio::test]
async fn nesting_with_fallback() {
let app = Router::new().nest(
"/foo",
Router::new()
.route("/bar", get(|| async {}))
.fallback((|| async { "fallback" }).into_service()),
);
let client = TestClient::new(app);
assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK);
// this shouldn't exist because the fallback is inside the nested router
let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
// this should work since we get into the nested router
let res = client.get("/foo/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "fallback");
}
#[tokio::test]
async fn or() {
let one = Router::new().route("/one", get(|| async {}));

View file

@ -1,7 +1,7 @@
#![allow(clippy::blacklisted_name)]
use crate::error_handling::HandleErrorLayer;
use crate::extract::MatchedPath;
use crate::extract::{Extension, MatchedPath};
use crate::BoxError;
use crate::{
extract::{self, Path},
@ -29,7 +29,6 @@ use std::{
};
use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder};
use tower_http::auth::RequireAuthorizationLayer;
use tower_http::trace::TraceLayer;
use tower_service::Service;
pub(crate) use helpers::*;
@ -592,24 +591,92 @@ async fn with_and_without_trailing_slash() {
assert_eq!(res.text().await, "without tsr");
}
#[derive(Clone)]
struct SetMatchedPathExtension<S>(S);
impl<B, S> Service<Request<B>> for SetMatchedPathExtension<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
let path = req
.extensions()
.get::<MatchedPath>()
.unwrap()
.as_str()
.to_string();
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
self.0.call(req)
}
}
#[derive(Clone)]
struct MatchedPathFromMiddleware(String);
#[tokio::test]
async fn access_matched_path() {
let api = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move { path.as_str().to_string() }),
);
async fn handler(
path: MatchedPath,
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
MatchedPathFromMiddleware,
>,
) -> String {
format!(
"extractor = {}, middleware = {}",
path.as_str(),
path_from_middleware
)
}
let app = Router::new()
.route(
"/:key",
get(|path: MatchedPath| async move { path.as_str().to_string() }),
)
.layer(
TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
let path = req.extensions().get::<MatchedPath>().unwrap().as_str();
tracing::info_span!("http-request", %path)
}),
);
.nest("/api", api)
.nest(
"/public",
Router::new().route("/assets/*path", get(handler)),
)
.nest("/foo", handler.into_service())
.layer(tower::layer::layer_fn(SetMatchedPathExtension));
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "/:key");
let res = client.get("/api/users/123").send().await;
assert_eq!(res.text().await, "/api/users/:id");
let res = client.get("/public/assets/css/style.css").send().await;
assert_eq!(
res.text().await,
"extractor = /public/assets/*path, middleware = /public/assets/*path"
);
let res = client.get("/foo/bar/baz").send().await;
assert_eq!(
res.text().await,
format!(
"extractor = /foo/*{}, middleware = /foo/*{}",
crate::routing::NEST_TAIL_PARAM,
crate::routing::NEST_TAIL_PARAM,
),
);
}
#[tokio::test]