checkpoint

This commit is contained in:
David Pedersen 2021-06-02 22:07:37 +02:00
parent 0d7e1e74c4
commit 00737c4e0a
3 changed files with 353 additions and 40 deletions

View file

@ -11,6 +11,7 @@ futures-util = "0.3"
http = "0.2" http = "0.2"
http-body = "0.4" http-body = "0.4"
hyper = "0.14" hyper = "0.14"
itertools = "0.10"
pin-project = "1.0" pin-project = "1.0"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
@ -31,9 +32,9 @@ tracing-subscriber = "0.2"
[dev-dependencies.tower-http] [dev-dependencies.tower-http]
version = "0.1" version = "0.1"
features = [ features = [
"trace", "add-extension",
"compression", "compression",
"compression-full", "compression-full",
"add-extension",
"fs", "fs",
"trace",
] ]

View file

@ -7,12 +7,14 @@ use crate::{
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future, ready}; use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode}; use http::{Method, Request, Response, StatusCode};
use itertools::{EitherOrBoth, Itertools};
use pin_project::pin_project; use pin_project::pin_project;
use std::{ use std::{
convert::Infallible, convert::Infallible,
fmt, fmt,
future::Future, future::Future,
pin::Pin, pin::Pin,
str,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{ use tower::{
@ -64,7 +66,7 @@ macro_rules! define_route_at_methods {
where where
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone, S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{ {
self.add_route_service(service, Method::$method) self.add_route_service(service, MethodOrPrefix::Method(Method::$method))
} }
}; };
@ -101,6 +103,19 @@ impl<R> RouteAt<R> {
define_route_at_methods!(RouteAt: connect, connect_service, CONNECT); define_route_at_methods!(RouteAt: connect, connect_service, CONNECT);
define_route_at_methods!(RouteAt: trace, trace_service, TRACE); define_route_at_methods!(RouteAt: trace, trace_service, TRACE);
pub fn nest<T>(
self,
other: RouteBuilder<T>,
) -> RouteBuilder<Or<StripPrefix<IntoService<T>>, R>> {
let route_spec = self.route_spec.clone();
let other = StripPrefix::new(other.into_service(), route_spec.clone());
self.add_route_service_with_spec(
other,
RouteSpec::new(MethodOrPrefix::Prefix(route_spec.clone()), route_spec),
)
}
fn add_route<H, B, T>( fn add_route<H, B, T>(
self, self,
handler: H, handler: H,
@ -109,10 +124,23 @@ impl<R> RouteAt<R> {
where where
H: Handler<B, T>, H: Handler<B, T>,
{ {
self.add_route_service(HandlerSvc::new(handler), method) self.add_route_service(HandlerSvc::new(handler), MethodOrPrefix::Method(method))
} }
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Or<S, R>> { fn add_route_service<S>(
self,
service: S,
method_or_prefix: MethodOrPrefix,
) -> RouteBuilder<Or<S, R>> {
let route_spec = self.route_spec.clone();
self.add_route_service_with_spec(service, RouteSpec::new(method_or_prefix, route_spec))
}
fn add_route_service_with_spec<S>(
self,
service: S,
route_spec: RouteSpec,
) -> RouteBuilder<Or<S, R>> {
assert!( assert!(
self.route_spec.starts_with(b"/"), self.route_spec.starts_with(b"/"),
"route spec must start with a slash (`/`)" "route spec must start with a slash (`/`)"
@ -121,7 +149,7 @@ impl<R> RouteAt<R> {
let new_app = App { let new_app = App {
service_tree: Or { service_tree: Or {
service, service,
route_spec: RouteSpec::new(method, self.route_spec.clone()), route_spec,
fallback: self.app.service_tree, fallback: self.app.service_tree,
handler_ready: false, handler_ready: false,
fallback_ready: false, fallback_ready: false,
@ -246,54 +274,107 @@ where
} }
} }
#[derive(Clone)] #[derive(Debug, Clone)]
struct RouteSpec { struct RouteSpec {
method: Method, method_or_prefix: MethodOrPrefix,
spec: Bytes, spec: Bytes,
length_match: LengthMatch,
}
#[derive(Debug, Clone)]
enum MethodOrPrefix {
AnyMethod,
Method(Method),
Prefix(Bytes),
}
#[derive(Debug, Clone, Copy)]
enum LengthMatch {
Exact,
UriCanBeLonger,
} }
impl RouteSpec { impl RouteSpec {
fn new(method: Method, spec: impl Into<Bytes>) -> Self { fn new(method_or_prefix: MethodOrPrefix, spec: impl Into<Bytes>) -> Self {
Self { Self {
method, method_or_prefix,
spec: spec.into(), spec: spec.into(),
length_match: LengthMatch::Exact,
} }
} }
fn length_match(mut self, length_match: LengthMatch) -> Self {
self.length_match = length_match;
self
}
} }
impl RouteSpec { impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> { fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> {
if req.method() != self.method { // println!("route spec comparing `{:?}` and `{:?}`", self, req.uri());
return None;
match &self.method_or_prefix {
MethodOrPrefix::Method(method) => {
if req.method() != method {
return None;
}
}
MethodOrPrefix::AnyMethod => {}
MethodOrPrefix::Prefix(prefix) => {
let route_spec = RouteSpec::new(MethodOrPrefix::AnyMethod, prefix.clone())
.length_match(LengthMatch::UriCanBeLonger);
if let Some(params) = route_spec.matches(req) {
return Some(params);
}
}
} }
let spec_parts = self.spec.split(|b| *b == b'/');
let path = req.uri().path().as_bytes(); let path = req.uri().path().as_bytes();
let path_parts = path.split(|b| *b == b'/'); let path_parts = path.split(|b| *b == b'/');
let spec_parts = self.spec.split(|b| *b == b'/');
if spec_parts.clone().count() != path_parts.clone().count() {
return None;
}
let mut params = Vec::new(); let mut params = Vec::new();
spec_parts for pair in spec_parts.zip_longest(path_parts) {
.zip(path_parts) match pair {
.all(|(spec, path)| { EitherOrBoth::Both(spec, path) => {
if let Some(key) = spec.strip_prefix(b":") { println!(
let key = std::str::from_utf8(key).unwrap().to_string(); "both: ({:?}, {:?})",
if let Ok(value) = std::str::from_utf8(path) { str::from_utf8(spec).unwrap(),
params.push((key, value.to_string())); str::from_utf8(path).unwrap()
true );
} else { if let Some(key) = spec.strip_prefix(b":") {
false let key = str::from_utf8(key).unwrap().to_string();
if let Ok(value) = std::str::from_utf8(path) {
params.push((key, value.to_string()));
} else {
return None;
}
} else if spec != path {
return None;
} }
} else {
spec == path
} }
}) EitherOrBoth::Left(spec) => {
.then(|| params) println!("left: {:?}", str::from_utf8(spec).unwrap());
return None;
}
EitherOrBoth::Right(path) => {
println!("right: {:?}", str::from_utf8(path).unwrap());
match self.length_match {
LengthMatch::Exact => {
return None;
}
LengthMatch::UriCanBeLonger => {
return Some(params);
}
}
}
}
}
Some(params)
} }
} }
@ -502,6 +583,95 @@ where
.unwrap() .unwrap()
} }
#[derive(Debug, Clone)]
pub struct StripPrefix<S> {
inner: S,
prefix: Bytes,
}
impl<S> StripPrefix<S> {
fn new(inner: S, prefix: impl Into<Bytes>) -> Self {
Self {
inner,
prefix: prefix.into(),
}
}
}
impl<S, B> Service<Request<B>> for StripPrefix<S>
where
S: Service<Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
use http::uri::{PathAndQuery, Uri};
use std::convert::TryFrom;
println!("strip prefix {:?} of {:?}", self.prefix, req.uri().path());
let (mut request_parts, body) = req.into_parts();
let mut uri_parts = request_parts.uri.into_parts();
enum Control<T> {
Continue(T),
Break,
}
if let Some(path_and_query) = &uri_parts.path_and_query {
let path = path_and_query.path();
let prefix = str::from_utf8(&self.prefix).unwrap();
let iter = path
.split('/')
.zip_longest(prefix.split('/'))
.map(|pair| match pair {
EitherOrBoth::Both(path, prefix) => {
if prefix.starts_with(':') || path == prefix {
Control::Continue(path)
} else {
Control::Break
}
}
EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => Control::Break,
})
.take_while(|item| matches!(item, Control::Continue(_)))
.map(|item| {
if let Control::Continue(item) = item {
item
} else {
unreachable!()
}
});
let prefix_with_captures_updated =
Itertools::intersperse(iter, "/").collect::<String>();
if let Some(path_without_prefix) = path.strip_prefix(&prefix_with_captures_updated) {
let new = if let Some(query) = path_and_query.query() {
PathAndQuery::try_from(format!("{}?{}", &path_without_prefix, query)).unwrap()
} else {
PathAndQuery::try_from(path_without_prefix).unwrap()
};
uri_parts.path_and_query = Some(new);
}
}
request_parts.uri = Uri::from_parts(uri_parts).unwrap();
let req = Request::from_parts(request_parts, body);
self.inner.call(req)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[allow(unused_imports)] #[allow(unused_imports)]
@ -539,7 +709,7 @@ mod tests {
} }
fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1);
let req = Request::builder() let req = Request::builder()
.method(req_spec.0.clone()) .method(req_spec.0.clone())
.uri(req_spec.1) .uri(req_spec.1)
@ -548,16 +718,16 @@ mod tests {
assert!( assert!(
route.matches(&req).is_some(), route.matches(&req).is_some(),
"`{} {}` doesn't match `{} {}`", "`{} {}` doesn't match `{:?} {}`",
req.method(), req.method(),
req.uri().path(), req.uri().path(),
route.method, route.method_or_prefix,
std::str::from_utf8(&route.spec).unwrap(), str::from_utf8(&route.spec).unwrap(),
); );
} }
fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1);
let req = Request::builder() let req = Request::builder()
.method(req_spec.0.clone()) .method(req_spec.0.clone())
.uri(req_spec.1) .uri(req_spec.1)
@ -566,11 +736,64 @@ mod tests {
assert!( assert!(
route.matches(&req).is_none(), route.matches(&req).is_none(),
"`{} {}` shouldn't match `{} {}`", "`{} {}` shouldn't match `{:?} {}`",
req.method(), req.method(),
req.uri().path(), req.uri().path(),
route.method, route.method_or_prefix,
std::str::from_utf8(&route.spec).unwrap(), str::from_utf8(&route.spec).unwrap(),
);
}
#[tokio::test]
async fn strip_prefix() {
let mut svc = StripPrefix::new(
tower::service_fn(
|req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) },
),
"/foo",
);
assert_eq!(
svc.call(Request::builder().uri("/foo/bar").body(()).unwrap())
.await
.unwrap(),
"/bar"
);
assert_eq!(
svc.call(Request::builder().uri("/foo").body(()).unwrap())
.await
.unwrap(),
""
);
assert_eq!(
svc.call(
Request::builder()
.uri("http://example.com/foo/bar?key=value")
.body(())
.unwrap()
)
.await
.unwrap(),
"http://example.com/bar?key=value"
);
}
#[tokio::test]
async fn strip_prefix_with_capture() {
let mut svc = StripPrefix::new(
tower::service_fn(
|req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) },
),
"/:version/api",
);
assert_eq!(
svc.call(Request::builder().uri("/v0/api/foo").body(()).unwrap())
.await
.unwrap(),
"/foo"
); );
} }
} }

View file

@ -397,6 +397,95 @@ async fn layer_on_whole_router() {
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
} }
#[tokio::test]
async fn nesting() {
let api = app()
.at("/users")
.get(|_: Request<Body>| async { "users#index" })
.post(|_: Request<Body>| async { "users#create" })
.at("/users/:id")
.get(
|_: Request<Body>, params: extract::UrlParams<(i32,)>| async move {
let (id,) = params.0;
format!("users#show {}", id)
},
);
let app = app()
.at("/foo")
.get(|_: Request<Body>| async { "foo" })
.at("/api")
.nest(api)
.at("/bar")
.get(|_: Request<Body>| async { "bar" })
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/api/users", addr))
.send()
.await
.unwrap();
assert_eq!(res.text().await.unwrap(), "users#index");
let res = client
.post(format!("http://{}/api/users", addr))
.send()
.await
.unwrap();
assert_eq!(res.text().await.unwrap(), "users#create");
let res = client
.get(format!("http://{}/api/users/42", addr))
.send()
.await
.unwrap();
assert_eq!(res.text().await.unwrap(), "users#show 42");
let res = client
.get(format!("http://{}/foo", addr))
.send()
.await
.unwrap();
assert_eq!(res.text().await.unwrap(), "foo");
let res = client
.get(format!("http://{}/bar", addr))
.send()
.await
.unwrap();
assert_eq!(res.text().await.unwrap(), "bar");
}
#[tokio::test]
async fn nesting_with_dynamic_part() {
let api = app().at("/users/:id").get(
|_: Request<Body>, params: extract::UrlParams<(String, i32)>| async move {
let (version, id) = params.0;
format!("users#show {} {}", version, id)
},
);
let app = app().at("/:version/api").nest(api).into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.get(format!("http://{}/v0/api/users/123", addr))
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "users#show v0 123");
}
// TODO(david): nesting more deeply
// TODO(david): composing two apps // TODO(david): composing two apps
// TODO(david): composing two apps with one at a "sub path" // TODO(david): composing two apps with one at a "sub path"
// TODO(david): composing two boxed apps // TODO(david): composing two boxed apps