Nesting and more flexible routing dsl

This commit is contained in:
David Pedersen 2021-06-06 20:30:54 +02:00
parent b4e2750d6a
commit 1609191a74
3 changed files with 245 additions and 238 deletions

View file

@ -569,7 +569,7 @@ pub mod service;
#[doc(inline)] #[doc(inline)]
pub use self::{ pub use self::{
handler::{get, on, post, Handler}, handler::{get, on, post, Handler},
routing::AddRoute, routing::RoutingDsl,
}; };
pub use async_trait::async_trait; pub use async_trait::async_trait;
@ -581,7 +581,7 @@ pub mod prelude {
extract, extract,
handler::{get, on, post, Handler}, handler::{get, on, post, Handler},
response, route, response, route,
routing::AddRoute, routing::RoutingDsl,
}; };
pub use http::Request; pub use http::Request;
} }

View file

@ -1,7 +1,8 @@
use crate::{body::BoxBody, response::IntoResponse, ResultExt}; use crate::{body::BoxBody, response::IntoResponse, ResultExt};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future, ready}; use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode}; use http::{Method, Request, Response, StatusCode, Uri};
use http_body::Full;
use hyper::Body; use hyper::Body;
use itertools::Itertools; use itertools::Itertools;
use pin_project::pin_project; use pin_project::pin_project;
@ -62,36 +63,7 @@ pub struct Route<S, F> {
pub(crate) fallback: F, pub(crate) fallback: F,
} }
pub trait AddRoute: Sized { pub trait RoutingDsl: Sized {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone;
}
impl<S, F> Route<S, F> {
pub fn boxed<B>(self) -> BoxRoute<B>
where
Self: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
<Self as Service<Request<Body>>>::Future: Send,
B: From<String> + 'static,
{
ServiceBuilder::new()
.layer_fn(BoxRoute)
.buffer(1024)
.layer(BoxService::layer())
.service(self)
}
pub fn layer<L>(self, layer: L) -> Layered<L::Service>
where
L: Layer<Self>,
L::Service: Service<Request<Body>> + Clone,
{
Layered(layer.layer(self))
}
}
impl<S, F> AddRoute for Route<S, F> {
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self> fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where where
T: Service<Request<Body>, Error = Infallible> + Clone, T: Service<Request<Body>, Error = Infallible> + Clone,
@ -102,7 +74,42 @@ impl<S, F> AddRoute for Route<S, F> {
fallback: self, fallback: self,
} }
} }
fn nest<T>(self, spec: &str, svc: T) -> Nested<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Nested {
pattern: PathPattern::new(spec),
svc,
fallback: self,
} }
}
fn boxed<B>(self) -> BoxRoute<B>
where
Self: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
<Self as Service<Request<Body>>>::Future: Send,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
ServiceBuilder::new()
.layer_fn(BoxRoute)
.buffer(1024)
.layer(BoxService::layer())
.service(self)
}
fn layer<L>(self, layer: L) -> Layered<L::Service>
where
L: Layer<Self>,
L::Service: Service<Request<Body>> + Clone,
{
Layered(layer.layer(self))
}
}
impl<S, F> RoutingDsl for Route<S, F> {}
// ===== Routing service impls ===== // ===== Routing service impls =====
@ -130,7 +137,7 @@ where
} }
fn call(&mut self, mut req: Request<Body>) -> Self::Future { fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if let Some(captures) = self.pattern.matches(req.uri().path()) { if let Some(captures) = self.pattern.full_match(req.uri().path()) {
insert_url_params(&mut req, captures); insert_url_params(&mut req, captures);
let response_future = self.svc.clone().oneshot(req); let response_future = self.svc.clone().oneshot(req);
future::Either::Left(BoxResponseBody(response_future)) future::Either::Left(BoxResponseBody(response_future))
@ -178,18 +185,7 @@ where
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct EmptyRouter; pub struct EmptyRouter;
impl AddRoute for EmptyRouter { impl RoutingDsl for EmptyRouter {}
fn route<S>(self, spec: &str, svc: S) -> Route<S, Self>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<R> Service<R> for EmptyRouter { impl<R> Service<R> for EmptyRouter {
type Response = Response<Body>; type Response = Response<Body>;
@ -236,7 +232,7 @@ impl PathPattern {
.join("/"); .join("/");
let full_path_regex = let full_path_regex =
Regex::new(&format!("^{}$", pattern)).expect("invalid regex generated from route"); Regex::new(&format!("^{}", pattern)).expect("invalid regex generated from route");
Self(Arc::new(Inner { Self(Arc::new(Inner {
full_path_regex, full_path_regex,
@ -244,8 +240,26 @@ impl PathPattern {
})) }))
} }
pub(crate) fn matches(&self, path: &str) -> Option<Captures> { pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
self.do_match(path).and_then(|match_| {
if match_.full_match {
Some(match_.captures)
} else {
None
}
})
}
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
self.do_match(path)
.map(|match_| (match_.matched, match_.captures))
}
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
self.0.full_path_regex.captures(path).map(|captures| { self.0.full_path_regex.captures(path).map(|captures| {
let matched = captures.get(0).unwrap();
let full_match = matched.as_str() == path;
let captures = self let captures = self
.0 .0
.capture_group_names .capture_group_names
@ -258,11 +272,22 @@ impl PathPattern {
.map(|(key, value)| (key.to_string(), value.to_string())) .map(|(key, value)| (key.to_string(), value.to_string()))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
captures Match {
captures,
full_match,
matched: matched.as_str(),
}
}) })
} }
} }
struct Match<'a> {
captures: Captures,
// true if regex matched whole path, false if it only matched a prefix
full_match: bool,
matched: &'a str,
}
type Captures = Vec<(String, String)>; type Captures = Vec<(String, String)>;
// ===== BoxRoute ===== // ===== BoxRoute =====
@ -275,24 +300,14 @@ impl<B> Clone for BoxRoute<B> {
} }
} }
impl<B> AddRoute for BoxRoute<B> { impl<B> RoutingDsl for BoxRoute<B> {}
fn route<S>(self, spec: &str, svc: S) -> Route<S, Self>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<B> Service<Request<Body>> for BoxRoute<B> impl<B> Service<Request<Body>> for BoxRoute<B>
where where
B: From<String> + 'static, B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{ {
type Response = Response<B>; type Response = Response<BoxBody>;
type Error = Infallible; type Error = Infallible;
type Future = BoxRouteResponseFuture<B>; type Future = BoxRouteResponseFuture<B>;
@ -317,29 +332,27 @@ type InnerFuture<B> = Oneshot<
impl<B> Future for BoxRouteResponseFuture<B> impl<B> Future for BoxRouteResponseFuture<B>
where where
B: From<String>, B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{ {
type Output = Result<Response<B>, Infallible>; type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match ready!(self.project().0.poll(cx)) { match ready!(self.project().0.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)), Ok(res) => Poll::Ready(Ok(res.map(BoxBody::new))),
Err(err) => Poll::Ready(Ok(handle_buffer_error(err))), Err(err) => Poll::Ready(Ok(handle_buffer_error(err))),
} }
} }
} }
fn handle_buffer_error<B>(error: BoxError) -> Response<B> fn handle_buffer_error(error: BoxError) -> Response<BoxBody> {
where
B: From<String>,
{
use tower::buffer::error::{Closed, ServiceError}; use tower::buffer::error::{Closed, ServiceError};
let error = match error.downcast::<Closed>() { let error = match error.downcast::<Closed>() {
Ok(closed) => { Ok(closed) => {
return Response::builder() return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) .status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(closed.to_string())) .body(BoxBody::new(Full::from(closed.to_string())))
.unwrap(); .unwrap();
} }
Err(e) => e, Err(e) => e,
@ -349,7 +362,7 @@ where
Ok(service_error) => { Ok(service_error) => {
return Response::builder() return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) .status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(format!("Service error: {}. This is a bug in tower-web. All inner services should be infallible. Please file an issue", service_error))) .body(BoxBody::new(Full::from(format!("Service error: {}. This is a bug in tower-web. All inner services should be infallible. Please file an issue", service_error))))
.unwrap(); .unwrap();
} }
Err(e) => e, Err(e) => e,
@ -357,10 +370,10 @@ where
Response::builder() Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) .status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(format!( .body(BoxBody::new(Full::from(format!(
"Uncountered an unknown error: {}. This should never happen. Please file an issue", "Uncountered an unknown error: {}. This should never happen. Please file an issue",
error error
))) ))))
.unwrap() .unwrap()
} }
@ -369,18 +382,7 @@ where
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Layered<S>(S); pub struct Layered<S>(S);
impl<S> AddRoute for Layered<S> { impl<S> RoutingDsl for Layered<S> {}
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<S> Layered<S> { impl<S> Layered<S> {
pub fn handle_error<F, B, Res>(self, f: F) -> HandleError<S, F> pub fn handle_error<F, B, Res>(self, f: F) -> HandleError<S, F>
@ -420,18 +422,7 @@ pub struct HandleError<S, F> {
f: F, f: F,
} }
impl<S, F> AddRoute for HandleError<S, F> { impl<S, F> RoutingDsl for HandleError<S, F> {}
fn route<T>(self, spec: &str, svc: T) -> Route<T, Self>
where
T: Service<Request<Body>, Error = Infallible> + Clone,
{
Route {
pattern: PathPattern::new(spec),
svc,
fallback: self,
}
}
}
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F> impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F>
where where
@ -487,6 +478,95 @@ where
} }
} }
// ===== nesting =====
pub fn nest<S>(spec: &str, svc: S) -> Nested<S, EmptyRouter>
where
S: Service<Request<Body>, Error = Infallible> + Clone,
{
Nested {
pattern: PathPattern::new(spec),
svc,
fallback: EmptyRouter,
}
}
#[derive(Debug, Clone)]
pub struct Nested<S, F> {
pattern: PathPattern,
svc: S,
fallback: F,
}
impl<S, F> RoutingDsl for Nested<S, F> {}
impl<S, F, SB, FB> Service<Request<Body>> for Nested<S, F>
where
S: Service<Request<Body>, Response = Response<SB>, Error = Infallible> + Clone,
SB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
SB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible> + Clone,
FB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
FB::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
#[allow(clippy::type_complexity)]
type Future = future::Either<
BoxResponseBody<Oneshot<S, Request<Body>>>,
BoxResponseBody<Oneshot<F, Request<Body>>>,
>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
let without_prefix = strip_prefix(req.uri(), prefix);
*req.uri_mut() = without_prefix;
insert_url_params(&mut req, captures);
let response_future = self.svc.clone().oneshot(req);
future::Either::Left(BoxResponseBody(response_future))
} else {
let response_future = self.fallback.clone().oneshot(req);
future::Either::Right(BoxResponseBody(response_future))
}
}
}
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path
} else {
path_and_query.path()
};
if let Some(query) = path_and_query.query() {
Some(
format!("{}?{}", new_path, query)
.parse::<http::uri::PathAndQuery>()
.unwrap(),
)
} else {
Some(new_path.parse().unwrap())
}
} else {
None
};
let mut parts = http::uri::Parts::default();
parts.scheme = uri.scheme().cloned();
parts.authority = uri.authority().cloned();
parts.path_and_query = path_and_query;
Uri::from_parts(parts).unwrap()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -514,7 +594,7 @@ mod tests {
fn assert_match(route_spec: &'static str, path: &'static str) { fn assert_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
assert!( assert!(
route.matches(path).is_some(), route.full_match(path).is_some(),
"`{}` doesn't match `{}`", "`{}` doesn't match `{}`",
path, path,
route_spec route_spec
@ -524,7 +604,7 @@ mod tests {
fn refute_match(route_spec: &'static str, path: &'static str) { fn refute_match(route_spec: &'static str, path: &'static str) {
let route = PathPattern::new(route_spec); let route = PathPattern::new(route_spec);
assert!( assert!(
route.matches(path).is_none(), route.full_match(path).is_none(),
"`{}` did match `{}` (but shouldn't)", "`{}` did match `{}` (but shouldn't)",
path, path,
route_spec route_spec

View file

@ -1,4 +1,4 @@
use crate::{extract, get, on, post, route, routing::MethodFilter, service, AddRoute, Handler}; use crate::{extract, get, on, post, route, routing::MethodFilter, service, Handler, RoutingDsl};
use http::{Request, Response, StatusCode}; use http::{Request, Response, StatusCode};
use hyper::{Body, Server}; use hyper::{Body, Server};
use serde::Deserialize; use serde::Deserialize;
@ -283,6 +283,7 @@ async fn boxing() {
"hi from POST" "hi from POST"
}), }),
) )
.layer(tower_http::compression::CompressionLayer::new())
.boxed(); .boxed();
let addr = run_in_background(app).await; let addr = run_in_background(app).await;
@ -485,150 +486,76 @@ async fn layer_on_whole_router() {
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
} }
// TODO(david): layer that changes the response body type to have a different error #[tokio::test]
async fn disjunction() {
let api_routes = route(
"/users",
get(|_: Request<Body>| async { "users#index" })
.post(|_: Request<Body>| async { "users#create" }),
)
.route(
"/users/:id",
get(
|_: Request<Body>, params: extract::UrlParamsMap| async move {
format!(
"{}: users#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
},
),
)
.route(
"/games/:id",
get(
|_: Request<Body>, params: extract::UrlParamsMap| async move {
format!(
"{}: games#show ({})",
params.get("version").unwrap(),
params.get("id").unwrap()
)
},
),
);
// // #[tokio::test] let app = route("/", get(|_: Request<Body>| async { "hi" })).nest("/:version/api", api_routes);
// // async fn nesting() {
// // let api = app()
// // .at("/users")
// // .get(|_: Request<Body>| async { "users#index" })
// // .post(|_: Request<Body>| async { "users#create" })
// // .at("/users/:id")
// // .get(
// // |_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
// // let (id,) = params.0;
// // format!("users#show {}", id)
// // },
// // );
// // let app = app() let addr = run_in_background(app).await;
// // .at("/foo")
// // .get(|_: Request<Body>| async { "foo" })
// // .at("/api")
// // .nest(api)
// // .at("/bar")
// // .get(|_: Request<Body>| async { "bar" })
// // .into_service();
// // let addr = run_in_background(app).await; let client = reqwest::Client::new();
// // let client = reqwest::Client::new(); let res = client
.get(format!("http://{}/", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "hi");
// // let res = client let res = client
// // .get(format!("http://{}/api/users", addr)) .get(format!("http://{}/v0/api/users", addr))
// // .send() .send()
// // .await .await
// // .unwrap(); .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#index"); assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "users#index");
// // let res = client let res = client
// // .post(format!("http://{}/api/users", addr)) .get(format!("http://{}/v0/api/users/123", addr))
// // .send() .send()
// // .await .await
// // .unwrap(); .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#create"); assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: users#show (123)");
// // let res = client let res = client
// // .get(format!("http://{}/api/users/42", addr)) .get(format!("http://{}/v0/api/games/123", addr))
// // .send() .send()
// // .await .await
// // .unwrap(); .unwrap();
// // assert_eq!(res.text().await.unwrap(), "users#show 42"); assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "v0: games#show (123)");
// // let res = client }
// // .get(format!("http://{}/foo", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "foo");
// // let res = client
// // .get(format!("http://{}/bar", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.text().await.unwrap(), "bar");
// // }
// // #[tokio::test]
// // async fn nesting_with_dynamic_part() {
// // let api = app().at("/users/:id").get(
// // |_: Request<Body>, params: extract::UrlParamsMap| async move {
// // // let (version, id) = params.0;
// // dbg!(&params);
// // let version = params.get("version").unwrap();
// // let id = params.get("id").unwrap();
// // format!("users#show {} {}", version, id)
// // },
// // );
// // let app = app().at("/:version/api").nest(api).into_service();
// // let addr = run_in_background(app).await;
// // let client = reqwest::Client::new();
// // let res = client
// // .get(format!("http://{}/v0/api/users/123", addr))
// // .send()
// // .await
// // .unwrap();
// // let status = res.status();
// // assert_eq!(res.text().await.unwrap(), "users#show v0 123");
// // assert_eq!(status, StatusCode::OK);
// // }
// // #[tokio::test]
// // async fn nesting_more_deeply() {
// // let users_api = app()
// // .at("/:id")
// // .get(|req: Request<Body>| async move {
// // dbg!(&req.uri().path());
// // "users#show"
// // });
// // let games_api = app()
// // .at("/")
// // .post(|req: Request<Body>| async move {
// // dbg!(&req.uri().path());
// // "games#create"
// // });
// // let api = app()
// // .at("/users")
// // .nest(users_api)
// // .at("/games")
// // .nest(games_api);
// // let app = app().at("/:version/api").nest(api).into_service();
// // let addr = run_in_background(app).await;
// // let client = reqwest::Client::new();
// // // let res = client
// // // .get(format!("http://{}/v0/api/users/123", addr))
// // // .send()
// // // .await
// // // .unwrap();
// // // assert_eq!(res.status(), StatusCode::OK);
// // println!("============================");
// // let res = client
// // .post(format!("http://{}/v0/api/games", addr))
// // .send()
// // .await
// // .unwrap();
// // assert_eq!(res.status(), StatusCode::OK);
// // }
// // TODO(david): nesting more deeply
// // TODO(david): composing two apps
// // TODO(david): composing two apps with one at a "sub path"
// // TODO(david): composing two boxed apps
// // TODO(david): composing two apps that have had layers applied
/// Run a `tower::Service` in the background and get a URI for it. /// Run a `tower::Service` in the background and get a URI for it.
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr