From 00737c4e0a3740782c980ab01f0a864af298503a Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 2 Jun 2021 22:07:37 +0200 Subject: [PATCH] checkpoint --- Cargo.toml | 5 +- src/routing.rs | 299 ++++++++++++++++++++++++++++++++++++++++++------- src/tests.rs | 89 +++++++++++++++ 3 files changed, 353 insertions(+), 40 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 19f0caf2..2654bf49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ futures-util = "0.3" http = "0.2" http-body = "0.4" hyper = "0.14" +itertools = "0.10" pin-project = "1.0" serde = "1.0" serde_json = "1.0" @@ -31,9 +32,9 @@ tracing-subscriber = "0.2" [dev-dependencies.tower-http] version = "0.1" features = [ - "trace", + "add-extension", "compression", "compression-full", - "add-extension", "fs", + "trace", ] diff --git a/src/routing.rs b/src/routing.rs index 21117c37..dc1c40d4 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -7,12 +7,14 @@ use crate::{ use bytes::Bytes; use futures_util::{future, ready}; use http::{Method, Request, Response, StatusCode}; +use itertools::{EitherOrBoth, Itertools}; use pin_project::pin_project; use std::{ convert::Infallible, fmt, future::Future, pin::Pin, + str, task::{Context, Poll}, }; use tower::{ @@ -64,7 +66,7 @@ macro_rules! define_route_at_methods { where S: Service, Response = Response, Error = Infallible> + Clone, { - self.add_route_service(service, Method::$method) + self.add_route_service(service, MethodOrPrefix::Method(Method::$method)) } }; @@ -101,6 +103,19 @@ impl RouteAt { define_route_at_methods!(RouteAt: connect, connect_service, CONNECT); define_route_at_methods!(RouteAt: trace, trace_service, TRACE); + pub fn nest( + self, + other: RouteBuilder, + ) -> RouteBuilder>, R>> { + let route_spec = self.route_spec.clone(); + let other = StripPrefix::new(other.into_service(), route_spec.clone()); + + self.add_route_service_with_spec( + other, + RouteSpec::new(MethodOrPrefix::Prefix(route_spec.clone()), route_spec), + ) + } + fn add_route( self, handler: H, @@ -109,10 +124,23 @@ impl RouteAt { where H: Handler, { - self.add_route_service(HandlerSvc::new(handler), method) + self.add_route_service(HandlerSvc::new(handler), MethodOrPrefix::Method(method)) } - fn add_route_service(self, service: S, method: Method) -> RouteBuilder> { + fn add_route_service( + self, + service: S, + method_or_prefix: MethodOrPrefix, + ) -> RouteBuilder> { + let route_spec = self.route_spec.clone(); + self.add_route_service_with_spec(service, RouteSpec::new(method_or_prefix, route_spec)) + } + + fn add_route_service_with_spec( + self, + service: S, + route_spec: RouteSpec, + ) -> RouteBuilder> { assert!( self.route_spec.starts_with(b"/"), "route spec must start with a slash (`/`)" @@ -121,7 +149,7 @@ impl RouteAt { let new_app = App { service_tree: Or { service, - route_spec: RouteSpec::new(method, self.route_spec.clone()), + route_spec, fallback: self.app.service_tree, handler_ready: false, fallback_ready: false, @@ -246,54 +274,107 @@ where } } -#[derive(Clone)] +#[derive(Debug, Clone)] struct RouteSpec { - method: Method, + method_or_prefix: MethodOrPrefix, spec: Bytes, + length_match: LengthMatch, +} + +#[derive(Debug, Clone)] +enum MethodOrPrefix { + AnyMethod, + Method(Method), + Prefix(Bytes), +} + +#[derive(Debug, Clone, Copy)] +enum LengthMatch { + Exact, + UriCanBeLonger, } impl RouteSpec { - fn new(method: Method, spec: impl Into) -> Self { + fn new(method_or_prefix: MethodOrPrefix, spec: impl Into) -> Self { Self { - method, + method_or_prefix, spec: spec.into(), + length_match: LengthMatch::Exact, } } + + fn length_match(mut self, length_match: LengthMatch) -> Self { + self.length_match = length_match; + self + } } impl RouteSpec { fn matches(&self, req: &Request) -> Option> { - if req.method() != self.method { - return None; + // println!("route spec comparing `{:?}` and `{:?}`", self, req.uri()); + + match &self.method_or_prefix { + MethodOrPrefix::Method(method) => { + if req.method() != method { + return None; + } + } + MethodOrPrefix::AnyMethod => {} + MethodOrPrefix::Prefix(prefix) => { + let route_spec = RouteSpec::new(MethodOrPrefix::AnyMethod, prefix.clone()) + .length_match(LengthMatch::UriCanBeLonger); + + if let Some(params) = route_spec.matches(req) { + return Some(params); + } + } } + let spec_parts = self.spec.split(|b| *b == b'/'); + let path = req.uri().path().as_bytes(); let path_parts = path.split(|b| *b == b'/'); - let spec_parts = self.spec.split(|b| *b == b'/'); - - if spec_parts.clone().count() != path_parts.clone().count() { - return None; - } - let mut params = Vec::new(); - spec_parts - .zip(path_parts) - .all(|(spec, path)| { - if let Some(key) = spec.strip_prefix(b":") { - let key = std::str::from_utf8(key).unwrap().to_string(); - if let Ok(value) = std::str::from_utf8(path) { - params.push((key, value.to_string())); - true - } else { - false + for pair in spec_parts.zip_longest(path_parts) { + match pair { + EitherOrBoth::Both(spec, path) => { + println!( + "both: ({:?}, {:?})", + str::from_utf8(spec).unwrap(), + str::from_utf8(path).unwrap() + ); + if let Some(key) = spec.strip_prefix(b":") { + let key = str::from_utf8(key).unwrap().to_string(); + if let Ok(value) = std::str::from_utf8(path) { + params.push((key, value.to_string())); + } else { + return None; + } + } else if spec != path { + return None; } - } else { - spec == path } - }) - .then(|| params) + EitherOrBoth::Left(spec) => { + println!("left: {:?}", str::from_utf8(spec).unwrap()); + return None; + } + EitherOrBoth::Right(path) => { + println!("right: {:?}", str::from_utf8(path).unwrap()); + match self.length_match { + LengthMatch::Exact => { + return None; + } + LengthMatch::UriCanBeLonger => { + return Some(params); + } + } + } + } + } + + Some(params) } } @@ -502,6 +583,95 @@ where .unwrap() } +#[derive(Debug, Clone)] +pub struct StripPrefix { + inner: S, + prefix: Bytes, +} + +impl StripPrefix { + fn new(inner: S, prefix: impl Into) -> Self { + Self { + inner, + prefix: prefix.into(), + } + } +} + +impl Service> for StripPrefix +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + use http::uri::{PathAndQuery, Uri}; + use std::convert::TryFrom; + + println!("strip prefix {:?} of {:?}", self.prefix, req.uri().path()); + + let (mut request_parts, body) = req.into_parts(); + let mut uri_parts = request_parts.uri.into_parts(); + + enum Control { + Continue(T), + Break, + } + + if let Some(path_and_query) = &uri_parts.path_and_query { + let path = path_and_query.path(); + + let prefix = str::from_utf8(&self.prefix).unwrap(); + + let iter = path + .split('/') + .zip_longest(prefix.split('/')) + .map(|pair| match pair { + EitherOrBoth::Both(path, prefix) => { + if prefix.starts_with(':') || path == prefix { + Control::Continue(path) + } else { + Control::Break + } + } + EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => Control::Break, + }) + .take_while(|item| matches!(item, Control::Continue(_))) + .map(|item| { + if let Control::Continue(item) = item { + item + } else { + unreachable!() + } + }); + let prefix_with_captures_updated = + Itertools::intersperse(iter, "/").collect::(); + + if let Some(path_without_prefix) = path.strip_prefix(&prefix_with_captures_updated) { + let new = if let Some(query) = path_and_query.query() { + PathAndQuery::try_from(format!("{}?{}", &path_without_prefix, query)).unwrap() + } else { + PathAndQuery::try_from(path_without_prefix).unwrap() + }; + uri_parts.path_and_query = Some(new); + } + } + + request_parts.uri = Uri::from_parts(uri_parts).unwrap(); + + let req = Request::from_parts(request_parts, body); + + self.inner.call(req) + } +} + #[cfg(test)] mod tests { #[allow(unused_imports)] @@ -539,7 +709,7 @@ mod tests { } fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { - let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); + let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1); let req = Request::builder() .method(req_spec.0.clone()) .uri(req_spec.1) @@ -548,16 +718,16 @@ mod tests { assert!( route.matches(&req).is_some(), - "`{} {}` doesn't match `{} {}`", + "`{} {}` doesn't match `{:?} {}`", req.method(), req.uri().path(), - route.method, - std::str::from_utf8(&route.spec).unwrap(), + route.method_or_prefix, + str::from_utf8(&route.spec).unwrap(), ); } fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { - let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); + let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1); let req = Request::builder() .method(req_spec.0.clone()) .uri(req_spec.1) @@ -566,11 +736,64 @@ mod tests { assert!( route.matches(&req).is_none(), - "`{} {}` shouldn't match `{} {}`", + "`{} {}` shouldn't match `{:?} {}`", req.method(), req.uri().path(), - route.method, - std::str::from_utf8(&route.spec).unwrap(), + route.method_or_prefix, + str::from_utf8(&route.spec).unwrap(), + ); + } + + #[tokio::test] + async fn strip_prefix() { + let mut svc = StripPrefix::new( + tower::service_fn( + |req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) }, + ), + "/foo", + ); + + assert_eq!( + svc.call(Request::builder().uri("/foo/bar").body(()).unwrap()) + .await + .unwrap(), + "/bar" + ); + + assert_eq!( + svc.call(Request::builder().uri("/foo").body(()).unwrap()) + .await + .unwrap(), + "" + ); + + assert_eq!( + svc.call( + Request::builder() + .uri("http://example.com/foo/bar?key=value") + .body(()) + .unwrap() + ) + .await + .unwrap(), + "http://example.com/bar?key=value" + ); + } + + #[tokio::test] + async fn strip_prefix_with_capture() { + let mut svc = StripPrefix::new( + tower::service_fn( + |req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) }, + ), + "/:version/api", + ); + + assert_eq!( + svc.call(Request::builder().uri("/v0/api/foo").body(()).unwrap()) + .await + .unwrap(), + "/foo" ); } } diff --git a/src/tests.rs b/src/tests.rs index 836c63c3..4f30b172 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -397,6 +397,95 @@ async fn layer_on_whole_router() { assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); } +#[tokio::test] +async fn nesting() { + let api = app() + .at("/users") + .get(|_: Request| async { "users#index" }) + .post(|_: Request| async { "users#create" }) + .at("/users/:id") + .get( + |_: Request, params: extract::UrlParams<(i32,)>| async move { + let (id,) = params.0; + format!("users#show {}", id) + }, + ); + + let app = app() + .at("/foo") + .get(|_: Request| async { "foo" }) + .at("/api") + .nest(api) + .at("/bar") + .get(|_: Request| async { "bar" }) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/api/users", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "users#index"); + + let res = client + .post(format!("http://{}/api/users", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "users#create"); + + let res = client + .get(format!("http://{}/api/users/42", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "users#show 42"); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "foo"); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.text().await.unwrap(), "bar"); +} + +#[tokio::test] +async fn nesting_with_dynamic_part() { + let api = app().at("/users/:id").get( + |_: Request, params: extract::UrlParams<(String, i32)>| async move { + let (version, id) = params.0; + format!("users#show {} {}", version, id) + }, + ); + + let app = app().at("/:version/api").nest(api).into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/v0/api/users/123", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#show v0 123"); +} + +// TODO(david): nesting more deeply + // TODO(david): composing two apps // TODO(david): composing two apps with one at a "sub path" // TODO(david): composing two boxed apps