mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-29 07:48:39 +01:00
Implement nesting as regular wildcard routes (#426)
When fixing bugs with `MatchedPath` (introduced to fix https://github.com/tokio-rs/axum/issues/386) I realized nesting could basically be implemented using regular routes, if we can detect that the service passed to `nest` is in fact a `Router`. Then we can transfer over all its routes and add the prefix. This makes nesting much simpler in general and should also be slightly faster since we're no longer nesting routers.
This commit is contained in:
parent
fc9bfb8a50
commit
92f96a201c
5 changed files with 223 additions and 153 deletions
|
@ -6,7 +6,7 @@ use crate::{
|
|||
body::{box_body, Body, BoxBody},
|
||||
extract::{
|
||||
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
||||
OriginalUri,
|
||||
MatchedPath, OriginalUri,
|
||||
},
|
||||
util::{ByteStr, PercentDecodedByteStr},
|
||||
BoxError,
|
||||
|
@ -33,9 +33,9 @@ pub mod service_method_router;
|
|||
mod into_make_service;
|
||||
mod method_filter;
|
||||
mod method_not_allowed;
|
||||
mod nested;
|
||||
mod not_found;
|
||||
mod route;
|
||||
mod strip_prefix;
|
||||
|
||||
pub(crate) use self::method_not_allowed::MethodNotAllowed;
|
||||
pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
|
||||
|
@ -92,7 +92,8 @@ impl<B> fmt::Debug for Router<B> {
|
|||
}
|
||||
}
|
||||
|
||||
const NEST_TAIL_PARAM: &str = "__axum_internal_nest_capture";
|
||||
pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest";
|
||||
const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest";
|
||||
|
||||
impl<B> Router<B>
|
||||
where
|
||||
|
@ -342,20 +343,43 @@ where
|
|||
panic!("Invalid route: nested routes cannot contain wildcards (*)");
|
||||
}
|
||||
|
||||
let id = RouteId::next();
|
||||
let prefix = path;
|
||||
|
||||
let path = if path == "/" {
|
||||
format!("/*{}", NEST_TAIL_PARAM)
|
||||
} else {
|
||||
format!("{}/*{}", path, NEST_TAIL_PARAM)
|
||||
};
|
||||
match try_downcast::<Router<B>, _>(svc) {
|
||||
// if the user is nesting a `Router` we can implement nesting
|
||||
// by simplying copying all the routes and adding the prefix in
|
||||
// front
|
||||
Ok(router) => {
|
||||
let Router {
|
||||
mut routes,
|
||||
node,
|
||||
fallback: _,
|
||||
} = router;
|
||||
|
||||
if let Err(err) = self.node.insert(path, id) {
|
||||
panic!("Invalid route: {}", err);
|
||||
for (id, nested_path) in node.paths {
|
||||
let route = routes.remove(&id).unwrap();
|
||||
let full_path = if &*nested_path == "/" {
|
||||
path.to_string()
|
||||
} else {
|
||||
format!("{}{}", path, nested_path)
|
||||
};
|
||||
self = self.route(&full_path, strip_prefix::StripPrefix::new(route, prefix));
|
||||
}
|
||||
|
||||
debug_assert!(routes.is_empty());
|
||||
}
|
||||
// otherwise we add a wildcard route to the service
|
||||
Err(svc) => {
|
||||
let path = if path == "/" {
|
||||
format!("/*{}", NEST_TAIL_PARAM)
|
||||
} else {
|
||||
format!("{}/*{}", path, NEST_TAIL_PARAM)
|
||||
};
|
||||
|
||||
self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix));
|
||||
}
|
||||
}
|
||||
|
||||
self.routes.insert(id, Route::new(nested::Nested { svc }));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -700,31 +724,8 @@ where
|
|||
/// ## When used with `Router::nest`
|
||||
///
|
||||
/// If a router with a fallback is nested inside another router the fallback
|
||||
/// will only apply to requests that matches the prefix:
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{
|
||||
/// Router,
|
||||
/// routing::get,
|
||||
/// handler::Handler,
|
||||
/// response::IntoResponse,
|
||||
/// http::{StatusCode, Uri},
|
||||
/// };
|
||||
///
|
||||
/// let api = Router::new()
|
||||
/// .route("/", get(|| async { /* ... */ }))
|
||||
/// .fallback(api_fallback.into_service());
|
||||
///
|
||||
/// let app = Router::new().nest("/api", api);
|
||||
///
|
||||
/// async fn api_fallback() -> impl IntoResponse { /* ... */ }
|
||||
///
|
||||
/// // `api_fallback` will be called for `/api/some-unknown-path` but not for
|
||||
/// // `/some-unknown-path` as the path doesn't start with `/api`
|
||||
/// # async {
|
||||
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
/// of the nested router will be discarded and not used. This is such that
|
||||
/// the outer router's fallback takes precedence.
|
||||
pub fn fallback<T>(mut self, svc: T) -> Self
|
||||
where
|
||||
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>
|
||||
|
@ -743,8 +744,22 @@ where
|
|||
req.extensions_mut().insert(id);
|
||||
|
||||
if let Some(matched_path) = self.node.paths.get(&id) {
|
||||
req.extensions_mut()
|
||||
.insert(crate::extract::MatchedPath(matched_path.clone()));
|
||||
let matched_path = if let Some(previous) = req.extensions_mut().get::<MatchedPath>() {
|
||||
// a previous `MatchedPath` might exist if we're inside a nested Router
|
||||
let previous = if let Some(previous) =
|
||||
previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE)
|
||||
{
|
||||
previous
|
||||
} else {
|
||||
previous.as_str()
|
||||
};
|
||||
|
||||
let matched_path = format!("{}{}", previous, matched_path);
|
||||
matched_path.into()
|
||||
} else {
|
||||
Arc::clone(matched_path)
|
||||
};
|
||||
req.extensions_mut().insert(MatchedPath(matched_path));
|
||||
}
|
||||
|
||||
let params = match_
|
||||
|
@ -754,11 +769,6 @@ where
|
|||
.map(|(key, value)| (key.to_string(), value.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
|
||||
req.extensions_mut()
|
||||
.insert(nested::NestMatchTail(tail.to_string()));
|
||||
}
|
||||
|
||||
insert_url_params(&mut req, params);
|
||||
|
||||
let route = self
|
||||
|
@ -970,14 +980,22 @@ impl<B> Fallback<B> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
fn try_downcast<T, K>(k: K) -> Result<T, K>
|
||||
where
|
||||
T: 'static,
|
||||
K: Send + 'static,
|
||||
{
|
||||
use std::any::Any;
|
||||
|
||||
#[test]
|
||||
fn traits() {
|
||||
use crate::tests::*;
|
||||
|
||||
assert_send::<Router<()>>();
|
||||
let k = Box::new(k) as Box<dyn Any + Send + 'static>;
|
||||
match k.downcast() {
|
||||
Ok(t) => Ok(*t),
|
||||
Err(other) => Err(*other.downcast().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn traits() {
|
||||
use crate::tests::*;
|
||||
assert_send::<Router<()>>();
|
||||
}
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
use crate::body::BoxBody;
|
||||
use http::{Request, Response, Uri};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::util::Oneshot;
|
||||
use tower::ServiceExt;
|
||||
use tower_service::Service;
|
||||
|
||||
/// A [`Service`] that has been nested inside a router at some path.
|
||||
///
|
||||
/// Created with [`Router::nest`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct Nested<S> {
|
||||
pub(super) svc: S,
|
||||
}
|
||||
|
||||
impl<B, S> Service<Request<B>> for Nested<S>
|
||||
where
|
||||
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
|
||||
B: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = Infallible;
|
||||
type Future = Oneshot<S, Request<B>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
// strip the prefix from the URI just before calling the inner service
|
||||
// such that any surrounding middleware still see the full path
|
||||
if let Some(tail) = req.extensions_mut().remove::<NestMatchTail>() {
|
||||
UriStack::push(&mut req);
|
||||
let new_uri = super::with_path(req.uri(), &tail.0);
|
||||
*req.uri_mut() = new_uri;
|
||||
}
|
||||
|
||||
self.svc.clone().oneshot(req)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct UriStack(Vec<Uri>);
|
||||
|
||||
impl UriStack {
|
||||
fn push<B>(req: &mut Request<B>) {
|
||||
let uri = req.uri().clone();
|
||||
|
||||
if let Some(stack) = req.extensions_mut().get_mut::<Self>() {
|
||||
stack.0.push(uri);
|
||||
} else {
|
||||
req.extensions_mut().insert(Self(vec![uri]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct NestMatchTail(pub(super) String);
|
||||
|
||||
#[test]
|
||||
fn traits() {
|
||||
use crate::tests::*;
|
||||
|
||||
assert_send::<Nested<()>>();
|
||||
assert_sync::<Nested<()>>();
|
||||
}
|
77
src/routing/strip_prefix.rs
Normal file
77
src/routing/strip_prefix.rs
Normal file
|
@ -0,0 +1,77 @@
|
|||
use http::{Request, Uri};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower_service::Service;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct StripPrefix<S> {
|
||||
inner: S,
|
||||
prefix: Arc<str>,
|
||||
}
|
||||
|
||||
impl<S> StripPrefix<S> {
|
||||
pub(super) fn new(inner: S, prefix: &str) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
prefix: prefix.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B> Service<Request<B>> for StripPrefix<S>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
let new_uri = strip_prefix(req.uri(), &self.prefix);
|
||||
*req.uri_mut() = new_uri;
|
||||
self.inner.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||
path
|
||||
} else {
|
||||
path_and_query.path()
|
||||
};
|
||||
|
||||
let new_path = if new_path.starts_with('/') {
|
||||
Cow::Borrowed(new_path)
|
||||
} else {
|
||||
Cow::Owned(format!("/{}", new_path))
|
||||
};
|
||||
|
||||
if let Some(query) = path_and_query.query() {
|
||||
Some(
|
||||
format!("{}?{}", new_path, query)
|
||||
.parse::<http::uri::PathAndQuery>()
|
||||
.unwrap(),
|
||||
)
|
||||
} else {
|
||||
Some(new_path.parse().unwrap())
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut parts = http::uri::Parts::default();
|
||||
parts.scheme = uri.scheme().cloned();
|
||||
parts.authority = uri.authority().cloned();
|
||||
parts.path_and_query = path_and_query;
|
||||
|
||||
Uri::from_parts(parts).unwrap()
|
||||
}
|
|
@ -31,29 +31,6 @@ async fn nest() {
|
|||
assert_eq!(res.text().await, "fallback");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn nesting_with_fallback() {
|
||||
let app = Router::new().nest(
|
||||
"/foo",
|
||||
Router::new()
|
||||
.route("/bar", get(|| async {}))
|
||||
.fallback((|| async { "fallback" }).into_service()),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK);
|
||||
|
||||
// this shouldn't exist because the fallback is inside the nested router
|
||||
let res = client.get("/does-not-exist").send().await;
|
||||
assert_eq!(res.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
// this should work since we get into the nested router
|
||||
let res = client.get("/foo/does-not-exist").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
assert_eq!(res.text().await, "fallback");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn or() {
|
||||
let one = Router::new().route("/one", get(|| async {}));
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#![allow(clippy::blacklisted_name)]
|
||||
|
||||
use crate::error_handling::HandleErrorLayer;
|
||||
use crate::extract::MatchedPath;
|
||||
use crate::extract::{Extension, MatchedPath};
|
||||
use crate::BoxError;
|
||||
use crate::{
|
||||
extract::{self, Path},
|
||||
|
@ -29,7 +29,6 @@ use std::{
|
|||
};
|
||||
use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder};
|
||||
use tower_http::auth::RequireAuthorizationLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_service::Service;
|
||||
|
||||
pub(crate) use helpers::*;
|
||||
|
@ -592,24 +591,92 @@ async fn with_and_without_trailing_slash() {
|
|||
assert_eq!(res.text().await, "without tsr");
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SetMatchedPathExtension<S>(S);
|
||||
|
||||
impl<B, S> Service<Request<B>> for SetMatchedPathExtension<S>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.0.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
||||
let path = req
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.to_string();
|
||||
req.extensions_mut().insert(MatchedPathFromMiddleware(path));
|
||||
self.0.call(req)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MatchedPathFromMiddleware(String);
|
||||
|
||||
#[tokio::test]
|
||||
async fn access_matched_path() {
|
||||
let api = Router::new().route(
|
||||
"/users/:id",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_string() }),
|
||||
);
|
||||
|
||||
async fn handler(
|
||||
path: MatchedPath,
|
||||
Extension(MatchedPathFromMiddleware(path_from_middleware)): Extension<
|
||||
MatchedPathFromMiddleware,
|
||||
>,
|
||||
) -> String {
|
||||
format!(
|
||||
"extractor = {}, middleware = {}",
|
||||
path.as_str(),
|
||||
path_from_middleware
|
||||
)
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/:key",
|
||||
get(|path: MatchedPath| async move { path.as_str().to_string() }),
|
||||
)
|
||||
.layer(
|
||||
TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
|
||||
let path = req.extensions().get::<MatchedPath>().unwrap().as_str();
|
||||
tracing::info_span!("http-request", %path)
|
||||
}),
|
||||
);
|
||||
.nest("/api", api)
|
||||
.nest(
|
||||
"/public",
|
||||
Router::new().route("/assets/*path", get(handler)),
|
||||
)
|
||||
.nest("/foo", handler.into_service())
|
||||
.layer(tower::layer::layer_fn(SetMatchedPathExtension));
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client.get("/foo").send().await;
|
||||
assert_eq!(res.text().await, "/:key");
|
||||
|
||||
let res = client.get("/api/users/123").send().await;
|
||||
assert_eq!(res.text().await, "/api/users/:id");
|
||||
|
||||
let res = client.get("/public/assets/css/style.css").send().await;
|
||||
assert_eq!(
|
||||
res.text().await,
|
||||
"extractor = /public/assets/*path, middleware = /public/assets/*path"
|
||||
);
|
||||
|
||||
let res = client.get("/foo/bar/baz").send().await;
|
||||
assert_eq!(
|
||||
res.text().await,
|
||||
format!(
|
||||
"extractor = /foo/*{}, middleware = /foo/*{}",
|
||||
crate::routing::NEST_TAIL_PARAM,
|
||||
crate::routing::NEST_TAIL_PARAM,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
Loading…
Reference in a new issue