diff --git a/src/lib.rs b/src/lib.rs index 9c82f402..3538c3bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ Tests use async_trait::async_trait; use bytes::Bytes; -use futures_util::future; +use futures_util::{future, ready}; use http::{Method, Request, Response, StatusCode}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ @@ -47,26 +47,31 @@ pub fn app() -> App<EmptyRouter> { } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct App<R> { router: R, } impl<R> App<R> { - pub fn at(self, route_spec: &str) -> RouteBuilder<R> { - RouteBuilder { + pub fn at(self, route_spec: &str) -> RouteAt<R> { + self.at_bytes(Bytes::copy_from_slice(route_spec.as_bytes())) + } + + fn at_bytes(self, route_spec: Bytes) -> RouteAt<R> { + RouteAt { app: self, - route_spec: Bytes::copy_from_slice(route_spec.as_bytes()), + route_spec, } } } -pub struct RouteBuilder<R> { +#[derive(Debug, Clone)] +pub struct RouteAt<R> { app: App<R>, route_spec: Bytes, } -impl<R> RouteBuilder<R> { +impl<R> RouteAt<R> { pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> where F: Handler<T>, @@ -81,14 +86,6 @@ impl<R> RouteBuilder<R> { self.add_route(handler_fn, Method::POST) } - pub fn at(self, route_spec: &str) -> Self { - self.app.at(route_spec) - } - - pub fn into_service(self) -> App<R> { - self.app - } - fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>> where H: Handler<T>, @@ -104,6 +101,8 @@ impl<R> RouteBuilder<R> { spec: self.route_spec.clone(), }, fallback: self.app.router, + handler_ready: false, + fallback_ready: false, }, }; @@ -114,9 +113,47 @@ impl<R> RouteBuilder<R> { } } +#[derive(Clone)] +pub struct RouteBuilder<R> { + app: App<R>, + route_spec: Bytes, +} + +impl<R> RouteBuilder<R> { + pub fn at(self, route_spec: &str) -> RouteAt<R> { + self.app.at(route_spec) + } + + pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> + where + F: Handler<T>, + { + self.app.at_bytes(self.route_spec).get(handler_fn) + } + + pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> + where + F: Handler<T>, + { + self.app.at_bytes(self.route_spec).post(handler_fn) + } +} + #[derive(Debug, thiserror::Error)] #[non_exhaustive] -pub enum Error {} +pub enum Error { + #[error("failed to deserialize the request body")] + DeserializeRequestBody(#[source] serde_json::Error), + + #[error("failed to consume the body")] + ConsumeBody(#[source] hyper::Error), + + #[error("URI contained no query string")] + QueryStringMissing, + + #[error("failed to deserialize query string")] + DeserializeQueryString(#[from] serde_urlencoded::de::Error), +} #[async_trait] pub trait Handler<Out> { @@ -136,37 +173,49 @@ where } } -#[async_trait] -#[allow(non_snake_case)] -impl<F, Fut, T1> Handler<(T1,)> for F -where - F: Fn(Request<Body>, T1) -> Fut + Send + Sync, - Fut: Future<Output = Result<Response<Body>, Error>> + Send, - T1: FromRequest + Send, -{ - async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> { - let T1 = T1::from_request(&mut req).await; - let res = self(req, T1).await?; - Ok(res) - } +macro_rules! impl_handler { + ( $head:ident $(,)? ) => { + #[async_trait] + #[allow(non_snake_case)] + impl<F, Fut, $head> Handler<($head,)> for F + where + F: Fn(Request<Body>, $head) -> Fut + Send + Sync, + Fut: Future<Output = Result<Response<Body>, Error>> + Send, + $head: FromRequest + Send, + { + async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> { + let $head = $head::from_request(&mut req).await?; + let res = self(req, $head).await?; + Ok(res) + } + } + }; + + ( $head:ident, $($tail:ident),* $(,)? ) => { + #[async_trait] + #[allow(non_snake_case)] + impl<F, Fut, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F + where + F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync, + Fut: Future<Output = Result<Response<Body>, Error>> + Send, + $head: FromRequest + Send, + $( $tail: FromRequest + Send, )* + { + async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> { + let $head = $head::from_request(&mut req).await?; + $( + let $tail = $tail::from_request(&mut req).await?; + )* + let res = self(req, $head, $($tail,)*).await?; + Ok(res) + } + } + + impl_handler!($($tail,)*); + }; } -#[async_trait] -#[allow(non_snake_case)] -impl<F, Fut, T1, T2> Handler<(T1, T2)> for F -where - F: Fn(Request<Body>, T1, T2) -> Fut + Send + Sync, - Fut: Future<Output = Result<Response<Body>, Error>> + Send, - T1: FromRequest + Send, - T2: FromRequest + Send, -{ - async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> { - let T1 = T1::from_request(&mut req).await; - let T2 = T2::from_request(&mut req).await; - let res = self(req, T1, T2).await?; - Ok(res) - } -} +impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); pub struct HandlerSvc<H, T> { handler: H, @@ -206,24 +255,26 @@ where #[async_trait] pub trait FromRequest: Sized { - async fn from_request(req: &mut Request<Body>) -> Self; + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error>; } -pub struct Query<T>(Result<T, QueryError>); - -impl<T> Query<T> { - pub fn into_inner(self) -> Result<T, QueryError> { - self.0 +#[async_trait] +impl<T> FromRequest for Option<T> +where + T: FromRequest, +{ + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { + Ok(T::from_request(req).await.ok()) } } -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum QueryError { - #[error("URI contained no query string")] - Missing, - #[error("failed to deserialize query string")] - Deserialize(#[from] serde_urlencoded::de::Error), +#[derive(Debug, Clone, Copy)] +pub struct Query<T>(T); + +impl<T> Query<T> { + pub fn into_inner(self) -> T { + self.0 + } } #[async_trait] @@ -231,58 +282,44 @@ impl<T> FromRequest for Query<T> where T: DeserializeOwned, { - async fn from_request(req: &mut Request<Body>) -> Self { - let result = (|| { - let query = req.uri().query().ok_or(QueryError::Missing)?; - let value = serde_urlencoded::from_str(query)?; - Ok(value) - })(); - Query(result) + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { + let query = req.uri().query().ok_or(Error::QueryStringMissing)?; + let value = serde_urlencoded::from_str(query)?; + Ok(Query(value)) } } -pub struct Json<T>(Result<T, JsonError>); +#[derive(Debug, Clone, Copy)] +pub struct Json<T>(T); impl<T> Json<T> { - pub fn into_inner(self) -> Result<T, JsonError> { + pub fn into_inner(self) -> T { self.0 } } -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum JsonError { - #[error("failed to consume the body")] - ConsumeBody(#[from] hyper::Error), - #[error("failed to deserialize the body")] - Deserialize(#[from] serde_json::Error), -} - #[async_trait] impl<T> FromRequest for Json<T> where T: DeserializeOwned, { - async fn from_request(req: &mut Request<Body>) -> Self { + async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> { // TODO(david): require the body to have `content-type: application/json` let body = std::mem::take(req.body_mut()); - let result = async move { - let bytes = hyper::body::to_bytes(body).await?; - let value = serde_json::from_slice(&bytes)?; - Ok(value) - } - .await; - - Json(result) + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeBody)?; + let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; + Ok(Json(value)) } } #[derive(Clone, Copy)] pub struct EmptyRouter(()); -impl Service<Request<Body>> for EmptyRouter { +impl<R> Service<R> for EmptyRouter { type Response = Response<Body>; type Error = Error; type Future = future::Ready<Result<Self::Response, Self::Error>>; @@ -291,7 +328,7 @@ impl Service<Request<Body>> for EmptyRouter { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Request<Body>) -> Self::Future { + fn call(&mut self, _req: R) -> Self::Future { let mut res = Response::new(Body::empty()); *res.status_mut() = StatusCode::NOT_FOUND; future::ready(Ok(res)) @@ -303,6 +340,8 @@ pub struct Route<H, F> { handler: H, route_spec: RouteSpec, fallback: F, + handler_ready: bool, + fallback_ready: bool, } #[derive(Clone)] @@ -320,51 +359,76 @@ impl RouteSpec { impl<H, F> Service<Request<Body>> for Route<H, F> where - H: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static, - H::Future: Send, - F: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static, - F::Future: Send, + H: Service<Request<Body>, Response = Response<Body>, Error = Error>, + F: Service<Request<Body>, Response = Response<Body>, Error = Error>, { type Response = Response<Body>; type Error = Error; - type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>; + type Future = future::Either<H::Future, F::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + if !self.handler_ready { + ready!(self.handler.poll_ready(cx))?; + self.handler_ready = true; + } + + if !self.fallback_ready { + ready!(self.fallback.poll_ready(cx))?; + self.fallback_ready = true; + } - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - // TODO(david): do we need to drive readiness in `call`? Poll::Ready(Ok(())) } fn call(&mut self, req: Request<Body>) -> Self::Future { if self.route_spec.matches(&req) { - let handler_clone = self.handler.clone(); - let mut handler = std::mem::replace(&mut self.handler, handler_clone); - Box::pin(async move { handler.ready().await?.call(req).await }) + self.handler_ready = false; + future::Either::Left(self.handler.call(req)) } else { - let fallback_clone = self.fallback.clone(); - let mut fallback = std::mem::replace(&mut self.fallback, fallback_clone); - Box::pin(async move { fallback.ready().await?.call(req).await }) + self.fallback_ready = false; + future::Either::Right(self.fallback.call(req)) } } } -impl<R> Service<Request<Body>> for App<R> +impl<R, T> Service<T> for App<R> where - R: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone, + R: Service<T>, { - type Response = Response<Body>; - type Error = Error; + type Response = R::Response; + type Error = R::Error; type Future = R::Future; - // TODO(david): handle backpressure + #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { self.router.poll_ready(cx) } - fn call(&mut self, req: Request<Body>) -> Self::Future { + #[inline] + fn call(&mut self, req: T) -> Self::Future { self.router.call(req) } } +impl<R, T> Service<T> for RouteBuilder<R> +where + App<R>: Service<T>, +{ + type Response = <App<R> as Service<T>>::Response; + type Error = <App<R> as Service<T>>::Error; + type Future = <App<R> as Service<T>>::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + self.app.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: T) -> Self::Future { + self.app.call(req) + } +} + #[cfg(test)] mod tests { #![allow(warnings)] @@ -377,8 +441,7 @@ mod tests { .get(root) .at("/users") .get(users_index) - .post(users_create) - .into_service(); + .post(users_create); let req = Request::builder() .method(Method::POST)