From e156bc40e17affee064c7e7489a6713d3cc88050 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 3 Jun 2021 21:36:39 +0200 Subject: [PATCH] not quite working --- src/extract.rs | 1 + src/routing.rs | 247 +++++-------------------------------------------- src/tests.rs | 195 +++++++++++++++++++++++--------------- 3 files changed, 147 insertions(+), 296 deletions(-) diff --git a/src/extract.rs b/src/extract.rs index 4ccbe6a6..f8b4b3c6 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -284,6 +284,7 @@ define_rejection! { pub struct MissingRouteParams(()); } +#[derive(Debug)] pub struct UrlParamsMap(HashMap); impl UrlParamsMap { diff --git a/src/routing.rs b/src/routing.rs index dc1c40d4..cba03b9f 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -66,7 +66,7 @@ macro_rules! define_route_at_methods { where S: Service, Response = Response, Error = Infallible> + Clone, { - self.add_route_service(service, MethodOrPrefix::Method(Method::$method)) + self.add_route_service(service, Method::$method) } }; @@ -103,19 +103,6 @@ impl RouteAt { define_route_at_methods!(RouteAt: connect, connect_service, CONNECT); define_route_at_methods!(RouteAt: trace, trace_service, TRACE); - pub fn nest( - self, - other: RouteBuilder, - ) -> RouteBuilder>, R>> { - let route_spec = self.route_spec.clone(); - let other = StripPrefix::new(other.into_service(), route_spec.clone()); - - self.add_route_service_with_spec( - other, - RouteSpec::new(MethodOrPrefix::Prefix(route_spec.clone()), route_spec), - ) - } - fn add_route( self, handler: H, @@ -124,16 +111,12 @@ impl RouteAt { where H: Handler, { - self.add_route_service(HandlerSvc::new(handler), MethodOrPrefix::Method(method)) + self.add_route_service(HandlerSvc::new(handler), method) } - fn add_route_service( - self, - service: S, - method_or_prefix: MethodOrPrefix, - ) -> RouteBuilder> { + fn add_route_service(self, service: S, method: Method) -> RouteBuilder> { let route_spec = self.route_spec.clone(); - self.add_route_service_with_spec(service, RouteSpec::new(method_or_prefix, route_spec)) + self.add_route_service_with_spec(service, RouteSpec::new(method, route_spec)) } fn add_route_service_with_spec( @@ -276,58 +259,24 @@ where #[derive(Debug, Clone)] struct RouteSpec { - method_or_prefix: MethodOrPrefix, + method: Method, spec: Bytes, - length_match: LengthMatch, -} - -#[derive(Debug, Clone)] -enum MethodOrPrefix { - AnyMethod, - Method(Method), - Prefix(Bytes), -} - -#[derive(Debug, Clone, Copy)] -enum LengthMatch { - Exact, - UriCanBeLonger, } impl RouteSpec { - fn new(method_or_prefix: MethodOrPrefix, spec: impl Into) -> Self { + fn new(method: Method, spec: impl Into) -> Self { Self { - method_or_prefix, + method, spec: spec.into(), - length_match: LengthMatch::Exact, } } - - fn length_match(mut self, length_match: LengthMatch) -> Self { - self.length_match = length_match; - self - } } impl RouteSpec { fn matches(&self, req: &Request) -> Option> { - // println!("route spec comparing `{:?}` and `{:?}`", self, req.uri()); - - match &self.method_or_prefix { - MethodOrPrefix::Method(method) => { - if req.method() != method { - return None; - } - } - MethodOrPrefix::AnyMethod => {} - MethodOrPrefix::Prefix(prefix) => { - let route_spec = RouteSpec::new(MethodOrPrefix::AnyMethod, prefix.clone()) - .length_match(LengthMatch::UriCanBeLonger); - - if let Some(params) = route_spec.matches(req) { - return Some(params); - } - } + // TODO(david): perform this matching outside + if req.method() != self.method { + return None; } let spec_parts = self.spec.split(|b| *b == b'/'); @@ -340,11 +289,6 @@ impl RouteSpec { for pair in spec_parts.zip_longest(path_parts) { match pair { EitherOrBoth::Both(spec, path) => { - println!( - "both: ({:?}, {:?})", - str::from_utf8(spec).unwrap(), - str::from_utf8(path).unwrap() - ); if let Some(key) = spec.strip_prefix(b":") { let key = str::from_utf8(key).unwrap().to_string(); if let Ok(value) = std::str::from_utf8(path) { @@ -356,21 +300,9 @@ impl RouteSpec { return None; } } - EitherOrBoth::Left(spec) => { - println!("left: {:?}", str::from_utf8(spec).unwrap()); + EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => { return None; } - EitherOrBoth::Right(path) => { - println!("right: {:?}", str::from_utf8(path).unwrap()); - match self.length_match { - LengthMatch::Exact => { - return None; - } - LengthMatch::UriCanBeLonger => { - return Some(params); - } - } - } } } @@ -419,7 +351,7 @@ where self.handler_ready = false; - req.extensions_mut().insert(Some(UrlParams(params))); + insert_url_params(&mut req, params); future::Either::Left(BoxResponseBody(self.service.call(req))) } else { @@ -436,6 +368,7 @@ where } } +#[derive(Debug)] pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>); #[pin_project] @@ -583,92 +516,13 @@ where .unwrap() } -#[derive(Debug, Clone)] -pub struct StripPrefix { - inner: S, - prefix: Bytes, -} - -impl StripPrefix { - fn new(inner: S, prefix: impl Into) -> Self { - Self { - inner, - prefix: prefix.into(), - } - } -} - -impl Service> for StripPrefix -where - S: Service>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { - use http::uri::{PathAndQuery, Uri}; - use std::convert::TryFrom; - - println!("strip prefix {:?} of {:?}", self.prefix, req.uri().path()); - - let (mut request_parts, body) = req.into_parts(); - let mut uri_parts = request_parts.uri.into_parts(); - - enum Control { - Continue(T), - Break, - } - - if let Some(path_and_query) = &uri_parts.path_and_query { - let path = path_and_query.path(); - - let prefix = str::from_utf8(&self.prefix).unwrap(); - - let iter = path - .split('/') - .zip_longest(prefix.split('/')) - .map(|pair| match pair { - EitherOrBoth::Both(path, prefix) => { - if prefix.starts_with(':') || path == prefix { - Control::Continue(path) - } else { - Control::Break - } - } - EitherOrBoth::Left(_) | EitherOrBoth::Right(_) => Control::Break, - }) - .take_while(|item| matches!(item, Control::Continue(_))) - .map(|item| { - if let Control::Continue(item) = item { - item - } else { - unreachable!() - } - }); - let prefix_with_captures_updated = - Itertools::intersperse(iter, "/").collect::(); - - if let Some(path_without_prefix) = path.strip_prefix(&prefix_with_captures_updated) { - let new = if let Some(query) = path_and_query.query() { - PathAndQuery::try_from(format!("{}?{}", &path_without_prefix, query)).unwrap() - } else { - PathAndQuery::try_from(path_without_prefix).unwrap() - }; - uri_parts.path_and_query = Some(new); - } - } - - request_parts.uri = Uri::from_parts(uri_parts).unwrap(); - - let req = Request::from_parts(request_parts, body); - - self.inner.call(req) +fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { + if let Some(current) = req.extensions_mut().get_mut::>() { + let mut current = current.take().unwrap(); + current.0.extend(params); + req.extensions_mut().insert(Some(current)); + } else { + req.extensions_mut().insert(Some(UrlParams(params))); } } @@ -709,7 +563,7 @@ mod tests { } fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { - let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1); + let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); let req = Request::builder() .method(req_spec.0.clone()) .uri(req_spec.1) @@ -721,13 +575,13 @@ mod tests { "`{} {}` doesn't match `{:?} {}`", req.method(), req.uri().path(), - route.method_or_prefix, + route.method, str::from_utf8(&route.spec).unwrap(), ); } fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { - let route = RouteSpec::new(MethodOrPrefix::Method(route_spec.0.clone()), route_spec.1); + let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); let req = Request::builder() .method(req_spec.0.clone()) .uri(req_spec.1) @@ -739,61 +593,8 @@ mod tests { "`{} {}` shouldn't match `{:?} {}`", req.method(), req.uri().path(), - route.method_or_prefix, + route.method, str::from_utf8(&route.spec).unwrap(), ); } - - #[tokio::test] - async fn strip_prefix() { - let mut svc = StripPrefix::new( - tower::service_fn( - |req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) }, - ), - "/foo", - ); - - assert_eq!( - svc.call(Request::builder().uri("/foo/bar").body(()).unwrap()) - .await - .unwrap(), - "/bar" - ); - - assert_eq!( - svc.call(Request::builder().uri("/foo").body(()).unwrap()) - .await - .unwrap(), - "" - ); - - assert_eq!( - svc.call( - Request::builder() - .uri("http://example.com/foo/bar?key=value") - .body(()) - .unwrap() - ) - .await - .unwrap(), - "http://example.com/bar?key=value" - ); - } - - #[tokio::test] - async fn strip_prefix_with_capture() { - let mut svc = StripPrefix::new( - tower::service_fn( - |req: Request<()>| async move { Ok::<_, Infallible>(req.uri().clone()) }, - ), - "/:version/api", - ); - - assert_eq!( - svc.call(Request::builder().uri("/v0/api/foo").body(()).unwrap()) - .await - .unwrap(), - "/foo" - ); - } } diff --git a/src/tests.rs b/src/tests.rs index 4f30b172..f9f0652d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -397,92 +397,141 @@ async fn layer_on_whole_router() { assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); } -#[tokio::test] -async fn nesting() { - let api = app() - .at("/users") - .get(|_: Request| async { "users#index" }) - .post(|_: Request| async { "users#create" }) - .at("/users/:id") - .get( - |_: Request, params: extract::UrlParams<(i32,)>| async move { - let (id,) = params.0; - format!("users#show {}", id) - }, - ); +// #[tokio::test] +// async fn nesting() { +// let api = app() +// .at("/users") +// .get(|_: Request| async { "users#index" }) +// .post(|_: Request| async { "users#create" }) +// .at("/users/:id") +// .get( +// |_: Request, params: extract::UrlParams<(i32,)>| async move { +// let (id,) = params.0; +// format!("users#show {}", id) +// }, +// ); - let app = app() - .at("/foo") - .get(|_: Request| async { "foo" }) - .at("/api") - .nest(api) - .at("/bar") - .get(|_: Request| async { "bar" }) - .into_service(); +// let app = app() +// .at("/foo") +// .get(|_: Request| async { "foo" }) +// .at("/api") +// .nest(api) +// .at("/bar") +// .get(|_: Request| async { "bar" }) +// .into_service(); - let addr = run_in_background(app).await; +// let addr = run_in_background(app).await; - let client = reqwest::Client::new(); +// let client = reqwest::Client::new(); - let res = client - .get(format!("http://{}/api/users", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.text().await.unwrap(), "users#index"); +// let res = client +// .get(format!("http://{}/api/users", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.text().await.unwrap(), "users#index"); - let res = client - .post(format!("http://{}/api/users", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.text().await.unwrap(), "users#create"); +// let res = client +// .post(format!("http://{}/api/users", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.text().await.unwrap(), "users#create"); - let res = client - .get(format!("http://{}/api/users/42", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.text().await.unwrap(), "users#show 42"); +// let res = client +// .get(format!("http://{}/api/users/42", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.text().await.unwrap(), "users#show 42"); - let res = client - .get(format!("http://{}/foo", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.text().await.unwrap(), "foo"); +// let res = client +// .get(format!("http://{}/foo", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.text().await.unwrap(), "foo"); - let res = client - .get(format!("http://{}/bar", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.text().await.unwrap(), "bar"); -} +// let res = client +// .get(format!("http://{}/bar", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.text().await.unwrap(), "bar"); +// } -#[tokio::test] -async fn nesting_with_dynamic_part() { - let api = app().at("/users/:id").get( - |_: Request, params: extract::UrlParams<(String, i32)>| async move { - let (version, id) = params.0; - format!("users#show {} {}", version, id) - }, - ); +// #[tokio::test] +// async fn nesting_with_dynamic_part() { +// let api = app().at("/users/:id").get( +// |_: Request, params: extract::UrlParamsMap| async move { +// // let (version, id) = params.0; +// dbg!(¶ms); +// let version = params.get("version").unwrap(); +// let id = params.get("id").unwrap(); +// format!("users#show {} {}", version, id) +// }, +// ); - let app = app().at("/:version/api").nest(api).into_service(); +// let app = app().at("/:version/api").nest(api).into_service(); - let addr = run_in_background(app).await; +// let addr = run_in_background(app).await; - let client = reqwest::Client::new(); +// let client = reqwest::Client::new(); - let res = client - .get(format!("http://{}/v0/api/users/123", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "users#show v0 123"); -} +// let res = client +// .get(format!("http://{}/v0/api/users/123", addr)) +// .send() +// .await +// .unwrap(); +// let status = res.status(); +// assert_eq!(res.text().await.unwrap(), "users#show v0 123"); +// assert_eq!(status, StatusCode::OK); +// } + +// #[tokio::test] +// async fn nesting_more_deeply() { +// let users_api = app() +// .at("/:id") +// .get(|req: Request| async move { +// dbg!(&req.uri().path()); +// "users#show" +// }); + +// let games_api = app() +// .at("/") +// .post(|req: Request| async move { +// dbg!(&req.uri().path()); +// "games#create" +// }); + +// let api = app() +// .at("/users") +// .nest(users_api) +// .at("/games") +// .nest(games_api); + +// let app = app().at("/:version/api").nest(api).into_service(); + +// let addr = run_in_background(app).await; + +// let client = reqwest::Client::new(); + +// // let res = client +// // .get(format!("http://{}/v0/api/users/123", addr)) +// // .send() +// // .await +// // .unwrap(); +// // assert_eq!(res.status(), StatusCode::OK); + +// println!("============================"); + +// let res = client +// .post(format!("http://{}/v0/api/games", addr)) +// .send() +// .await +// .unwrap(); +// assert_eq!(res.status(), StatusCode::OK); +// } // TODO(david): nesting more deeply