"matchit" based router (#363)

* "matchit" based router

* Update changelog

* Remove dependency on `regex`

* Docs

* Fix typos

* Also mention route order in root module docs

* Update CHANGELOG.md

Co-authored-by: Jonas Platte <jplatte@users.noreply.github.com>

* Document that `/:key` and `/foo` overlaps

* Provide good error message for wildcards in routes

* minor clean ups

* Make `Router` cheaper to clone

* Ensure middleware still only applies to routes above

* Remove call to issues from changelog

We're aware of the short coming :)

* Fix tests on 1.51

Co-authored-by: Jonas Platte <jplatte@users.noreply.github.com>
This commit is contained in:
David Pedersen 2021-10-24 15:22:49 +02:00 committed by GitHub
parent 9fcc884374
commit 1a78a3f224
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 553 additions and 413 deletions

View file

@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- Big internal refactoring of routing leading to several improvements ([#363])
- Wildcard routes like `.route("/api/users/*rest", service)` are now supported.
- The order routes are added in no longer matters.
- Adding a conflicting route will now cause a panic instead of silently making
a route unreachable.
- Route matching is faster as number of routes increase.
- The routes `/foo` and `/:key` are considered to overlap and will cause a
panic when constructing the router. This might be fixed in the future.
- Improve performance of `BoxRoute` ([#339])
- Expand accepted content types for JSON requests ([#378])
- **breaking:** Automatically do percent decoding in `extract::Path`
@ -25,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#286]: https://github.com/tokio-rs/axum/pull/286
[#272]: https://github.com/tokio-rs/axum/pull/272
[#378]: https://github.com/tokio-rs/axum/pull/378
[#363]: https://github.com/tokio-rs/axum/pull/363
[#396]: https://github.com/tokio-rs/axum/pull/396
# 0.2.8 (07. October, 2021)

View file

@ -32,9 +32,9 @@ futures-util = { version = "0.3", default-features = false, features = ["alloc"]
http = "0.2"
http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "stream"] }
matchit = "0.4"
percent-encoding = "2.1"
pin-project-lite = "0.2.7"
regex = "1.5"
serde = "1.0"
serde_urlencoded = "0.7"
sync_wrapper = "0.1.1"
@ -60,8 +60,8 @@ reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] }
uuid = { version = "0.8", features = ["serde", "v4"] }
tokio-stream = "0.1"
uuid = { version = "0.8", features = ["serde", "v4"] }
[dev-dependencies.tower]
package = "tower"

View file

@ -176,6 +176,7 @@ mod tests {
use super::*;
use crate::tests::*;
use crate::{handler::get, Router};
use std::collections::HashMap;
#[tokio::test]
async fn percent_decoding() {
@ -190,4 +191,27 @@ mod tests {
assert_eq!(res.text().await, "one two");
}
#[tokio::test]
async fn wildcard() {
let app = Router::new()
.route(
"/foo/*rest",
get(|Path(param): Path<String>| async move { param }),
)
.route(
"/bar/*rest",
get(|Path(params): Path<HashMap<String, String>>| async move {
params.get("rest").unwrap().clone()
}),
);
let client = TestClient::new(app);
let res = client.get("/foo/bar/baz").send().await;
assert_eq!(res.text().await, "/bar/baz");
let res = client.get("/bar/baz/qux").send().await;
assert_eq!(res.text().await, "/baz/qux");
}
}

View file

@ -7,10 +7,10 @@
//! - [Handlers](#handlers)
//! - [Debugging handler type errors](#debugging-handler-type-errors)
//! - [Routing](#routing)
//! - [Precedence](#precedence)
//! - [Matching multiple methods](#matching-multiple-methods)
//! - [Routing to any `Service`](#routing-to-any-service)
//! - [Routing to fallible services](#routing-to-fallible-services)
//! - [Wildcard routes](#wildcard-routes)
//! - [Nesting routes](#nesting-routes)
//! - [Extractors](#extractors)
//! - [Common extractors](#common-extractors)
@ -177,75 +177,8 @@
//!
//! You can also define routes separately and merge them with [`Router::or`].
//!
//! ## Precedence
//!
//! Note that routes are matched _bottom to top_ so routes that should have
//! higher precedence should be added _after_ routes with lower precedence:
//!
//! ```rust
//! use axum::{
//! body::{Body, BoxBody},
//! handler::get,
//! http::Request,
//! Router,
//! };
//! use tower::{Service, ServiceExt};
//! use http::{Method, Response, StatusCode};
//! use std::convert::Infallible;
//!
//! # #[tokio::main]
//! # async fn main() {
//! // `/foo` also matches `/:key` so adding the routes in this order means `/foo`
//! // will be inaccessible.
//! let mut app = Router::new()
//! .route("/foo", get(|| async { "/foo called" }))
//! .route("/:key", get(|| async { "/:key called" }));
//!
//! // Even though we use `/foo` as the request URI, `/:key` takes precedence
//! // since its defined last.
//! let (status, body) = call_service(&mut app, Method::GET, "/foo").await;
//! assert_eq!(status, StatusCode::OK);
//! assert_eq!(body, "/:key called");
//!
//! // We have to add `/foo` after `/:key` since routes are matched bottom to
//! // top.
//! let mut new_app = Router::new()
//! .route("/:key", get(|| async { "/:key called" }))
//! .route("/foo", get(|| async { "/foo called" }));
//!
//! // Now it works
//! let (status, body) = call_service(&mut new_app, Method::GET, "/foo").await;
//! assert_eq!(status, StatusCode::OK);
//! assert_eq!(body, "/foo called");
//!
//! // And the other route works as well
//! let (status, body) = call_service(&mut new_app, Method::GET, "/bar").await;
//! assert_eq!(status, StatusCode::OK);
//! assert_eq!(body, "/:key called");
//!
//! // Little helper function to make calling a service easier. Just for
//! // demonstration purposes.
//! async fn call_service<S>(
//! svc: &mut S,
//! method: Method,
//! uri: &str,
//! ) -> (StatusCode, String)
//! where
//! S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
//! {
//! let req = Request::builder().method(method).uri(uri).body(Body::empty()).unwrap();
//! let res = svc.ready().await.unwrap().call(req).await.unwrap();
//!
//! let status = res.status();
//!
//! let body = res.into_body();
//! let body = hyper::body::to_bytes(body).await.unwrap();
//! let body = String::from_utf8(body.to_vec()).unwrap();
//!
//! (status, body)
//! }
//! # }
//! ```
//! Routes are not allowed to overlap and will panic if an overlapping route is
//! added. This also means the order in which routes are added doesn't matter.
//!
//! ## Routing to any [`Service`]
//!
@ -376,6 +309,41 @@
//! See ["Error handling"](#error-handling) for more details on [`handle_error`]
//! and error handling in general.
//!
//! ## Wildcard routes
//!
//! axum also supports wildcard routes:
//!
//! ```rust,no_run
//! use axum::{
//! handler::get,
//! Router,
//! };
//!
//! let app = Router::new()
//! // this matches any request that starts with `/api`
//! .route("/api/*rest", get(|| async { /* ... */ }));
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! The matched path can be extracted via [`extract::Path`]:
//!
//! ```rust,no_run
//! use axum::{
//! handler::get,
//! extract::Path,
//! Router,
//! };
//!
//! let app = Router::new().route("/api/*rest", get(|Path(rest): Path<String>| async {
//! // `rest` will be everything after `/api`
//! }));
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! ## Nesting routes
//!
//! Routes can be nested by calling [`Router::nest`](routing::Router::nest):
@ -410,6 +378,25 @@
//! file serving to work. Use [`OriginalUri`] if you need the original request
//! URI.
//!
//! Nested routes are similar to wild card routes. The difference is that
//! wildcard routes still see the whole URI whereas nested routes will have
//! the prefix stripped.
//!
//! ```rust
//! use axum::{handler::get, http::Uri, Router};
//!
//! let app = Router::new()
//! .route("/foo/*rest", get(|uri: Uri| async {
//! // `uri` will contain `/foo`
//! }))
//! .nest("/bar", get(|uri: Uri| async {
//! // `uri` will _not_ contain `/bar`
//! }));
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! # Extractors
//!
//! An extractor is a type that implements [`FromRequest`]. Extractors is how

View file

@ -1,9 +1,11 @@
//! Future types.
use crate::{
body::BoxBody, clone_box_service::CloneBoxService, routing::FromEmptyRouter, BoxError,
body::BoxBody,
clone_box_service::CloneBoxService,
routing::{FromEmptyRouter, UriStack},
BoxError,
};
use futures_util::ready;
use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
@ -13,7 +15,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::{util::Oneshot, ServiceExt};
use tower::util::Oneshot;
use tower_service::Service;
pub use super::or::ResponseFuture as OrResponseFuture;
@ -76,12 +78,9 @@ where
S: Service<Request<B>>,
F: Service<Request<B>>,
{
pub(crate) fn a(a: Oneshot<S, Request<B>>, fallback: F) -> Self {
pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
RouteFuture {
state: RouteFutureInner::A {
a,
fallback: Some(fallback),
},
state: RouteFutureInner::A { a },
}
}
@ -103,7 +102,6 @@ pin_project! {
A {
#[pin]
a: Oneshot<S, Request<B>>,
fallback: Option<F>,
},
B {
#[pin]
@ -120,33 +118,10 @@ where
{
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let mut this = self.as_mut().project();
let new_state = match this.state.as_mut().project() {
RouteFutureInnerProj::A { a, fallback } => {
let mut response = ready!(a.poll(cx))?;
let req = if let Some(ext) =
response.extensions_mut().remove::<FromEmptyRouter<B>>()
{
ext.request
} else {
return Poll::Ready(Ok(response));
};
RouteFutureInner::B {
b: fallback
.take()
.expect("future polled after completion")
.oneshot(req),
}
}
RouteFutureInnerProj::B { b } => return b.poll(cx),
};
this.state.set(new_state);
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().state.project() {
RouteFutureInnerProj::A { a } => a.poll(cx),
RouteFutureInnerProj::B { b } => b.poll(cx),
}
}
}
@ -173,7 +148,20 @@ where
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
let mut res: Response<_> = futures_util::ready!(self.project().inner.poll(cx)?);
// `nest` mutates the URI of the request so if it turns out no route matched
// we need to reset the URI so the next routes see the original URI
//
// That requires using a stack since we can have arbitrarily nested routes
if let Some(from_empty_router) = res.extensions_mut().get_mut::<FromEmptyRouter<B>>() {
let uri = UriStack::pop(&mut from_empty_router.request);
if let Some(uri) = uri {
*from_empty_router.request.uri_mut() = uri;
}
}
Poll::Ready(Ok(res))
}
}

View file

@ -14,7 +14,7 @@ use crate::{
};
use bytes::Bytes;
use http::{Request, Response, StatusCode, Uri};
use regex::Regex;
use matchit::Node;
use std::{
borrow::Cow,
convert::Infallible,
@ -36,10 +36,32 @@ mod or;
pub use self::{method_filter::MethodFilter, or::Or};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct RouteId(u64);
impl RouteId {
fn next() -> Self {
use std::sync::atomic::{AtomicU64, Ordering};
static ID: AtomicU64 = AtomicU64::new(0);
Self(ID.fetch_add(1, Ordering::SeqCst))
}
}
/// The router type for composing handlers and services.
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct Router<S> {
svc: S,
routes: S,
node: MaybeSharedNode,
}
// optimization that allows us to only clone the whole `Node` if we're actually
// mutating it while building the router. Once we've created a `MakeService` we
// no longer need to add routes and can `Arc` the node making it cheaper to
// clone
#[derive(Clone)]
enum MaybeSharedNode {
NotShared(Node<RouteId>),
Shared(Arc<Node<RouteId>>),
}
impl<E> Router<EmptyRouter<E>> {
@ -49,7 +71,8 @@ impl<E> Router<EmptyRouter<E>> {
/// all requests.
pub fn new() -> Self {
Self {
svc: EmptyRouter::not_found(),
routes: EmptyRouter::not_found(),
node: MaybeSharedNode::NotShared(Node::new()),
}
}
}
@ -60,25 +83,19 @@ impl<E> Default for Router<EmptyRouter<E>> {
}
}
impl<S, R> Service<R> for Router<S>
impl<S> fmt::Debug for Router<S>
where
S: Service<R>,
S: fmt::Debug,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.svc.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: R) -> Self::Future {
self.svc.call(req)
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("routes", &self.routes)
.finish()
}
}
const NEST_TAIL_PARAM: &str = "__axum_nest";
impl<S> Router<S> {
/// Add another route to the router.
///
@ -119,13 +136,61 @@ impl<S> Router<S> {
///
/// # Panics
///
/// Panics if `path` doesn't start with `/`.
pub fn route<T>(self, path: &str, svc: T) -> Router<Route<T, S>> {
self.map(|fallback| Route {
pattern: PathPattern::new(path),
svc,
fallback,
})
/// Panics if the route overlaps with another route:
///
/// ```should_panic
/// use axum::{handler::get, Router};
///
/// let app = Router::new()
/// .route("/", get(|| async {}))
/// .route("/", get(|| async {}));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// This also applies to `nest` which is similar to a wildcard route:
///
/// ```should_panic
/// use axum::{handler::get, Router};
///
/// let app = Router::new()
/// // this is similar to `/api/*`
/// .nest("/api", get(|| async {}))
/// // which overlaps with this route
/// .route("/api/users", get(|| async {}));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Note that routes like `/:key` and `/foo` are considered overlapping:
///
/// ```should_panic
/// use axum::{handler::get, Router};
///
/// let app = Router::new()
/// .route("/foo", get(|| async {}))
/// .route("/:key", get(|| async {}));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn route<T>(mut self, path: &str, svc: T) -> Router<Route<T, S>> {
let id = RouteId::next();
if let Err(err) = self.update_node(|node| node.insert(path, id)) {
panic!("Invalid route: {}", err);
}
Router {
routes: Route {
id,
svc,
fallback: self.routes,
},
node: self.node,
}
}
/// Nest a group of routes (or a [`Service`]) at some path.
@ -213,13 +278,58 @@ impl<S> Router<S> {
/// making the type easier to name. This is sometimes useful when working with
/// `nest`.
///
/// # Wildcard routes
///
/// Nested routes are similar to wildcard routes. The difference is that
/// wildcard routes still see the whole URI whereas nested routes will have
/// the prefix stripped.
///
/// ```rust
/// use axum::{handler::get, http::Uri, Router};
///
/// let app = Router::new()
/// .route("/foo/*rest", get(|uri: Uri| async {
/// // `uri` will contain `/foo`
/// }))
/// .nest("/bar", get(|uri: Uri| async {
/// // `uri` will _not_ contain `/bar`
/// }));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Panics
///
/// Panics if the route overlaps with another route. See [`Router::route`]
/// for more details.
///
/// [`OriginalUri`]: crate::extract::OriginalUri
pub fn nest<T>(self, path: &str, svc: T) -> Router<Nested<T, S>> {
self.map(|fallback| Nested {
pattern: PathPattern::new(path),
svc,
fallback,
})
pub fn nest<T>(mut self, path: &str, svc: T) -> Router<Nested<T, S>> {
let id = RouteId::next();
if path.contains('*') {
panic!("Invalid route: nested routes cannot contain wildcards (*)");
}
let path = if path == "/" {
format!("/*{}", NEST_TAIL_PARAM)
} else {
format!("{}/*{}", path, NEST_TAIL_PARAM)
};
if let Err(err) = self.update_node(|node| node.insert(path, id)) {
panic!("Invalid route: {}", err);
}
Router {
routes: Nested {
id,
svc,
fallback: self.routes,
},
node: self.node,
}
}
/// Create a boxed route trait object.
@ -366,11 +476,11 @@ impl<S> Router<S> {
/// ```
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<S>
pub fn into_make_service(self) -> IntoMakeService<Self>
where
S: Clone,
{
IntoMakeService::new(self.svc)
IntoMakeService::new(self.into_shared_node())
}
/// Convert this router into a [`MakeService`], that will store `C`'s
@ -455,12 +565,12 @@ impl<S> Router<S> {
/// [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix_domain_socket.rs
pub fn into_make_service_with_connect_info<C, Target>(
self,
) -> IntoMakeServiceWithConnectInfo<S, C>
) -> IntoMakeServiceWithConnectInfo<Self, C>
where
S: Clone,
C: Connected<Target>,
{
IntoMakeServiceWithConnectInfo::new(self.svc)
IntoMakeServiceWithConnectInfo::new(self.into_shared_node())
}
/// Merge two routers into one.
@ -586,43 +696,112 @@ impl<S> Router<S> {
where
F: FnOnce(S) -> S2,
{
Router { svc: f(self.svc) }
}
}
/// A route that sends requests to one of two [`Service`]s depending on the
/// path.
#[derive(Debug, Clone)]
pub struct Route<S, F> {
pub(crate) pattern: PathPattern,
pub(crate) svc: S,
pub(crate) fallback: F,
}
impl<S, F, B> Service<Request<B>> for Route<S, F>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = RouteFuture<S, F, B>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
if let Some(captures) = self.pattern.full_match(&req) {
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut, self.fallback.clone())
} else {
let fut = self.fallback.clone().oneshot(req);
RouteFuture::b(fut)
Router {
routes: f(self.routes),
node: self.node,
}
}
fn update_node<F, T>(&mut self, f: F) -> T
where
F: FnOnce(&mut Node<RouteId>) -> T,
{
match &mut self.node {
MaybeSharedNode::NotShared(node) => f(node),
MaybeSharedNode::Shared(shared_node) => {
let mut node: Node<_> = Clone::clone(&*shared_node);
let result = f(&mut node);
self.node = MaybeSharedNode::NotShared(node);
result
}
}
}
fn get_node(&self) -> &Node<RouteId> {
match &self.node {
MaybeSharedNode::NotShared(node) => node,
MaybeSharedNode::Shared(shared_node) => &*shared_node,
}
}
fn into_shared_node(self) -> Self {
let node = match self.node {
MaybeSharedNode::NotShared(node) => MaybeSharedNode::Shared(Arc::new(node)),
MaybeSharedNode::Shared(shared_node) => {
MaybeSharedNode::Shared(Arc::clone(&shared_node))
}
};
Self {
routes: self.routes,
node,
}
}
}
impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for Router<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
ReqBody: Send + Sync + 'static,
{
type Response = Response<ResBody>;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.routes.poll_ready(cx)
}
#[inline]
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
let path = req.uri().path().to_string();
if let Ok(match_) = self.get_node().at(&path) {
let id = *match_.value;
req.extensions_mut().insert(id);
let params = match_
.params
.iter()
.filter(|(key, _)| !key.starts_with(NEST_TAIL_PARAM))
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>();
if let Some(tail) = match_.params.get(NEST_TAIL_PARAM) {
UriStack::push(&mut req);
let new_uri = with_path(req.uri(), tail);
*req.uri_mut() = new_uri;
}
insert_url_params(&mut req, params);
}
self.routes.call(req)
}
}
pub(crate) struct UriStack(Vec<Uri>);
impl UriStack {
fn push<B>(req: &mut Request<B>) {
let uri = req.uri().clone();
if let Some(stack) = req.extensions_mut().get_mut::<Self>() {
stack.0.push(uri);
} else {
req.extensions_mut().insert(Self(vec![uri]));
}
}
pub(crate) fn pop<B>(req: &mut Request<B>) -> Option<Uri> {
req.extensions_mut()
.get_mut::<Self>()
.and_then(|stack| stack.0.pop())
}
}
// we store the potential error here such that users can handle invalid path
@ -759,95 +938,84 @@ struct FromEmptyRouter<B> {
request: Request<B>,
}
/// A route that sends requests to one of two [`Service`]s depending on the
/// path.
#[derive(Debug, Clone)]
pub(crate) struct PathPattern(Arc<Inner>);
#[derive(Debug)]
struct Inner {
full_path_regex: Regex,
capture_group_names: Box<[Bytes]>,
pub struct Route<S, T> {
id: RouteId,
svc: S,
fallback: T,
}
impl PathPattern {
pub(crate) fn new(pattern: &str) -> Self {
assert!(pattern.starts_with('/'), "Route path must start with a `/`");
impl<B, S, T> Service<Request<B>> for Route<S, T>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = RouteFuture<S, T, B>;
let mut capture_group_names = Vec::new();
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
let pattern = pattern
.split('/')
.map(|part| {
if let Some(key) = part.strip_prefix(':') {
capture_group_names.push(Bytes::copy_from_slice(key.as_bytes()));
Cow::Owned(format!("(?P<{}>[^/]+)", key))
fn call(&mut self, req: Request<B>) -> Self::Future {
match req.extensions().get::<RouteId>() {
Some(id) => {
if self.id == *id {
RouteFuture::a(self.svc.clone().oneshot(req))
} else {
Cow::Borrowed(part)
RouteFuture::b(self.fallback.clone().oneshot(req))
}
})
.collect::<Vec<_>>()
.join("/");
let full_path_regex =
Regex::new(&format!("^{}", pattern)).expect("invalid regex generated from route");
Self(Arc::new(Inner {
full_path_regex,
capture_group_names: capture_group_names.into(),
}))
}
pub(crate) fn full_match<B>(&self, req: &Request<B>) -> Option<Captures> {
self.do_match(req).and_then(|match_| {
if match_.full_match {
Some(match_.captures)
} else {
None
}
})
}
pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request<B>) -> Option<(&'a str, Captures)> {
self.do_match(req)
.map(|match_| (match_.matched, match_.captures))
}
fn do_match<'a, B>(&self, req: &'a Request<B>) -> Option<Match<'a>> {
let path = req.uri().path();
self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap();
let full_match = matched.as_str() == path;
let captures = self
.0
.capture_group_names
.iter()
.map(|bytes| {
std::str::from_utf8(bytes)
.expect("bytes were created from str so is valid utf-8")
})
.filter_map(|name| captures.name(name).map(|value| (name, value.as_str())))
.map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>();
Match {
captures,
full_match,
matched: matched.as_str(),
}
})
None => RouteFuture::b(self.fallback.clone().oneshot(req)),
}
}
}
struct Match<'a> {
captures: Captures,
// true if regex matched whole path, false if it only matched a prefix
full_match: bool,
matched: &'a str,
/// A [`Service`] that has been nested inside a router at some path.
///
/// Created with [`Router::nest`].
#[derive(Debug, Clone)]
pub struct Nested<S, T> {
id: RouteId,
svc: S,
fallback: T,
}
type Captures = Vec<(String, String)>;
impl<B, S, T> Service<Request<B>> for Nested<S, T>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = NestedFuture<S, T, B>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let future = match req.extensions().get::<RouteId>() {
Some(id) => {
if self.id == *id {
RouteFuture::a(self.svc.clone().oneshot(req))
} else {
RouteFuture::b(self.fallback.clone().oneshot(req))
}
}
None => RouteFuture::b(self.fallback.clone().oneshot(req)),
};
NestedFuture { inner: future }
}
}
/// A boxed route trait object.
///
@ -889,60 +1057,8 @@ where
}
}
/// A [`Service`] that has been nested inside a router at some path.
///
/// Created with [`Router::nest`].
#[derive(Debug, Clone)]
pub struct Nested<S, F> {
pattern: PathPattern,
svc: S,
fallback: F,
}
impl<S, F, B> Service<Request<B>> for Nested<S, F>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Future = NestedFuture<S, F, B>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) {
let without_prefix = strip_prefix(req.uri(), prefix);
*req.uri_mut() = without_prefix;
insert_url_params(&mut req, captures);
let fut = self.svc.clone().oneshot(req);
RouteFuture::a(fut, self.fallback.clone())
} else {
let fut = self.fallback.clone().oneshot(req);
RouteFuture::b(fut)
};
NestedFuture { inner: f }
}
}
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
fn with_path(uri: &Uri, new_path: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
} else {
path_and_query.path()
};
let new_path = if new_path.starts_with('/') {
Cow::Borrowed(new_path)
} else {
@ -1032,48 +1148,6 @@ where
mod tests {
use super::*;
#[test]
fn test_routing() {
assert_match("/", "/");
assert_match("/foo", "/foo");
assert_match("/foo/", "/foo/");
refute_match("/foo", "/foo/");
refute_match("/foo/", "/foo");
assert_match("/foo/bar", "/foo/bar");
refute_match("/foo/bar/", "/foo/bar");
refute_match("/foo/bar", "/foo/bar/");
assert_match("/:value", "/foo");
assert_match("/users/:id", "/users/1");
assert_match("/users/:id/action", "/users/42/action");
refute_match("/users/:id/action", "/users/42");
refute_match("/users/:id", "/users/42/action");
}
fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(&req).is_some(),
"`{}` doesn't match `{}`",
path,
route_spec
);
}
fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec);
let req = Request::builder().uri(path).body(()).unwrap();
assert!(
route.full_match(&req).is_none(),
"`{}` did match `{}` (but shouldn't)",
path,
route_spec
);
}
#[test]
fn traits() {
use crate::tests::*;

View file

@ -50,14 +50,11 @@ where
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let original_uri = req.uri().clone();
ResponseFuture {
state: State::FirstFuture {
f: self.first.clone().oneshot(req),
},
second: Some(self.second.clone()),
original_uri: Some(original_uri),
}
}
}
@ -72,9 +69,6 @@ pin_project! {
#[pin]
state: State<A, B, ReqBody>,
second: Option<B>,
// Some services, namely `Nested`, mutates the request URI so we must
// restore it to its original state before calling `second`
original_uri: Option<http::Uri>,
}
}
@ -109,7 +103,7 @@ where
StateProj::FirstFuture { f } => {
let mut response = ready!(f.poll(cx)?);
let mut req = if let Some(ext) = response
let req = if let Some(ext) = response
.extensions_mut()
.remove::<FromEmptyRouter<ReqBody>>()
{
@ -118,8 +112,6 @@ where
return Poll::Ready(Ok(response));
};
*req.uri_mut() = this.original_uri.take().unwrap();
let second = this.second.take().expect("future polled after completion");
State::SecondFuture {

View file

@ -1,31 +1,14 @@
#![allow(unused_imports, dead_code)]
use crate::BoxError;
use crate::{
extract,
handler::{any, delete, get, on, patch, post, Handler},
response::IntoResponse,
routing::MethodFilter,
service, Router,
};
use bytes::Bytes;
use http::{
header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION},
Method, Request, StatusCode, Uri,
header::{HeaderName, HeaderValue},
Request, StatusCode,
};
use hyper::{Body, Server};
use serde::Deserialize;
use serde_json::json;
use std::future::Ready;
use std::{
collections::HashMap,
convert::{Infallible, TryFrom},
future::ready,
convert::TryFrom,
net::{SocketAddr, TcpListener},
task::{Context, Poll},
time::Duration,
};
use tower::{make::Shared, service_fn};
use tower::make::Shared;
use tower_service::Service;
pub(crate) struct TestClient {

View file

@ -25,6 +25,7 @@ use std::{
time::Duration,
};
use tower::service_fn;
use tower::timeout::TimeoutLayer;
use tower_service::Service;
pub(crate) use helpers::*;
@ -479,29 +480,6 @@ async fn handler_into_service() {
assert_eq!(res.text().await, "you said: hi there!");
}
#[tokio::test]
async fn when_multiple_routes_match() {
let app = Router::new()
.route("/", post(|| async {}))
.route("/", get(|| async {}))
.route("/foo", get(|| async {}))
.nest("/foo", Router::new().route("/bar", get(|| async {})));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.post("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn captures_dont_match_empty_segments() {
let app = Router::new().route("/:key", get(|| async {}));
@ -539,6 +517,33 @@ async fn json_content_types() {
assert!(!valid_json_content_type("text/json").await);
}
#[tokio::test]
async fn wildcard_sees_whole_url() {
let app = Router::new().route("/api/*rest", get(|uri: Uri| async move { uri.to_string() }));
let client = TestClient::new(app);
let res = client.get("/api/foo/bar").send().await;
assert_eq!(res.text().await, "/api/foo/bar");
}
#[tokio::test]
async fn middleware_applies_to_routes_above() {
let app = Router::new()
.route("/one", get(std::future::pending::<()>))
.layer(TimeoutLayer::new(Duration::new(0, 0)))
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))
.route("/two", get(|| async {}));
let client = TestClient::new(app);
let res = client.get("/one").send().await;
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
let res = client.get("/two").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
pub(crate) fn assert_send<T: Send>() {}
pub(crate) fn assert_sync<T: Sync>() {}
pub(crate) fn assert_unpin<T: Unpin>() {}

View file

@ -1,5 +1,6 @@
use super::*;
use crate::body::box_body;
use crate::routing::EmptyRouter;
use std::collections::HashMap;
#[tokio::test]
@ -13,6 +14,7 @@ async fn nesting_apps() {
"/users/:id",
get(
|params: extract::Path<HashMap<String, String>>| async move {
dbg!(&params);
format!(
"{}: users#show ({})",
params.get("version").unwrap(),
@ -179,3 +181,79 @@ async fn nest_static_file_server() {
let res = client.get("/static/README.md").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn nested_multiple_routes() {
let app = Router::new()
.nest(
"/api",
Router::new()
.route("/users", get(|| async { "users" }))
.route("/teams", get(|| async { "teams" })),
)
.route("/", get(|| async { "root" }));
let client = TestClient::new(app);
assert_eq!(client.get("/").send().await.text().await, "root");
assert_eq!(client.get("/api/users").send().await.text().await, "users");
assert_eq!(client.get("/api/teams").send().await.text().await, "teams");
}
#[tokio::test]
async fn nested_with_other_route_also_matching_with_route_first() {
let app = Router::new().route("/api", get(|| async { "api" })).nest(
"/api",
Router::new()
.route("/users", get(|| async { "users" }))
.route("/teams", get(|| async { "teams" })),
);
let client = TestClient::new(app);
assert_eq!(client.get("/api").send().await.text().await, "api");
assert_eq!(client.get("/api/users").send().await.text().await, "users");
assert_eq!(client.get("/api/teams").send().await.text().await, "teams");
}
#[tokio::test]
async fn nested_with_other_route_also_matching_with_route_last() {
let app = Router::new()
.nest(
"/api",
Router::new()
.route("/users", get(|| async { "users" }))
.route("/teams", get(|| async { "teams" })),
)
.route("/api", get(|| async { "api" }));
let client = TestClient::new(app);
assert_eq!(client.get("/api").send().await.text().await, "api");
assert_eq!(client.get("/api/users").send().await.text().await, "users");
assert_eq!(client.get("/api/teams").send().await.text().await, "teams");
}
#[tokio::test]
async fn multiple_top_level_nests() {
let app = Router::new()
.nest(
"/one",
Router::new().route("/route", get(|| async { "one" })),
)
.nest(
"/two",
Router::new().route("/route", get(|| async { "two" })),
);
let client = TestClient::new(app);
assert_eq!(client.get("/one/route").send().await.text().await, "one");
assert_eq!(client.get("/two/route").send().await.text().await, "two");
}
#[tokio::test]
#[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")]
async fn nest_cannot_contain_wildcards() {
Router::<EmptyRouter>::new().nest("/one/*rest", Router::<EmptyRouter>::new());
}

View file

@ -80,19 +80,19 @@ async fn multiple_ors_balanced_differently() {
}
#[tokio::test]
async fn or_nested_inside_other_thing() {
let inner = Router::new()
.route("/bar", get(|| async {}))
.or(Router::new().route("/baz", get(|| async {})));
let app = Router::new().nest("/foo", inner);
async fn nested_or() {
let bar = Router::new().route("/bar", get(|| async { "bar" }));
let baz = Router::new().route("/baz", get(|| async { "baz" }));
let client = TestClient::new(app);
let bar_or_baz = bar.or(baz);
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
let client = TestClient::new(bar_or_baz.clone());
assert_eq!(client.get("/bar").send().await.text().await, "bar");
assert_eq!(client.get("/baz").send().await.text().await, "baz");
let res = client.get("/foo/baz").send().await;
assert_eq!(res.status(), StatusCode::OK);
let client = TestClient::new(Router::new().nest("/foo", bar_or_baz));
assert_eq!(client.get("/foo/bar").send().await.text().await, "bar");
assert_eq!(client.get("/foo/baz").send().await.text().await, "baz");
}
#[tokio::test]