From 07294378b39f439d7fdcd528a6339733a9280006 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 29 May 2021 21:13:06 +0200 Subject: [PATCH] Initial pile of hacks --- .gitignore | 2 + Cargo.toml | 22 +++ src/lib.rs | 429 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 453 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..96ef6c0b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..eafc8c1b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "tower-web" +version = "0.1.0" +authors = ["David Pedersen "] +edition = "2018" + +[dependencies] +async-trait = "0.1" +bytes = "1.0" +http = "0.2" +http-body = "0.4" +hyper = "0.14" +serde = "1.0" +serde_urlencoded = "0.7" +serde_json = "1.0" +futures-util = "0.3" +tower = { version = "0.4", features = ["util"] } +thiserror = "1.0" + +[dev-dependencies] +tokio = { version = "1.6.1", features = ["macros", "rt"] } +serde = { version = "1.0", features = ["derive"] } diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..9c82f402 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,429 @@ +#![allow(unused_imports, dead_code)] + +/* + +Improvements to make: + +Somehow return generic "into response" kinda types without having to manually +create hyper::Body for everything + +Don't make Query and Json contain a Result, instead make generic wrapper +for "optional" inputs + +Make it possible to convert QueryError and JsonError into responses + +Support wrapping single routes in tower::Layer + +Support putting a tower::Service at a Route + +Don't require the response body to be hyper::Body, wont work if we're wrapping +single routes in layers + +Support extracting headers, perhaps via `headers::Header`? + +Implement `FromRequest` for more functions, with macro + +Tests + +*/ + +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::future; +use http::{Method, Request, Response, StatusCode}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + future::Future, + marker::PhantomData, + task::{Context, Poll}, +}; +use tower::{Service, ServiceExt}; + +pub use hyper::body::Body; + +pub fn app() -> App { + App { + router: EmptyRouter(()), + } +} + +#[derive(Clone)] +pub struct App { + router: R, +} + +impl App { + pub fn at(self, route_spec: &str) -> RouteBuilder { + RouteBuilder { + app: self, + route_spec: Bytes::copy_from_slice(route_spec.as_bytes()), + } + } +} + +pub struct RouteBuilder { + app: App, + route_spec: Bytes, +} + +impl RouteBuilder { + pub fn get(self, handler_fn: F) -> RouteBuilder, R>> + where + F: Handler, + { + self.add_route(handler_fn, Method::GET) + } + + pub fn post(self, handler_fn: F) -> RouteBuilder, R>> + where + F: Handler, + { + self.add_route(handler_fn, Method::POST) + } + + pub fn at(self, route_spec: &str) -> Self { + self.app.at(route_spec) + } + + pub fn into_service(self) -> App { + self.app + } + + fn add_route(self, handler: H, method: Method) -> RouteBuilder, R>> + where + H: Handler, + { + let new_app = App { + router: Route { + handler: HandlerSvc { + handler, + _input: PhantomData, + }, + route_spec: RouteSpec { + method, + spec: self.route_spec.clone(), + }, + fallback: self.app.router, + }, + }; + + RouteBuilder { + app: new_app, + route_spec: self.route_spec, + } + } +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum Error {} + +#[async_trait] +pub trait Handler { + async fn call(self, req: Request) -> Result, Error>; +} + +#[async_trait] +#[allow(non_snake_case)] +impl Handler<()> for F +where + F: Fn(Request) -> Fut + Send + Sync, + Fut: Future, Error>> + Send, +{ + async fn call(self, req: Request) -> Result, Error> { + let res = self(req).await?; + Ok(res) + } +} + +#[async_trait] +#[allow(non_snake_case)] +impl Handler<(T1,)> for F +where + F: Fn(Request, T1) -> Fut + Send + Sync, + Fut: Future, Error>> + Send, + T1: FromRequest + Send, +{ + async fn call(self, mut req: Request) -> Result, Error> { + let T1 = T1::from_request(&mut req).await; + let res = self(req, T1).await?; + Ok(res) + } +} + +#[async_trait] +#[allow(non_snake_case)] +impl Handler<(T1, T2)> for F +where + F: Fn(Request, T1, T2) -> Fut + Send + Sync, + Fut: Future, Error>> + Send, + T1: FromRequest + Send, + T2: FromRequest + Send, +{ + async fn call(self, mut req: Request) -> Result, Error> { + let T1 = T1::from_request(&mut req).await; + let T2 = T2::from_request(&mut req).await; + let res = self(req, T1, T2).await?; + Ok(res) + } +} + +pub struct HandlerSvc { + handler: H, + _input: PhantomData T>, +} + +impl Clone for HandlerSvc +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + _input: PhantomData, + } + } +} + +impl Service> for HandlerSvc +where + H: Handler + Clone + 'static, +{ + type Response = Response; + type Error = Error; + type Future = future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // HandlerSvc can only be constructed from async functions which are always ready + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let handler = self.handler.clone(); + Box::pin(Handler::call(handler, req)) + } +} + +#[async_trait] +pub trait FromRequest: Sized { + async fn from_request(req: &mut Request) -> Self; +} + +pub struct Query(Result); + +impl Query { + pub fn into_inner(self) -> Result { + self.0 + } +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum QueryError { + #[error("URI contained no query string")] + Missing, + #[error("failed to deserialize query string")] + Deserialize(#[from] serde_urlencoded::de::Error), +} + +#[async_trait] +impl FromRequest for Query +where + T: DeserializeOwned, +{ + async fn from_request(req: &mut Request) -> Self { + let result = (|| { + let query = req.uri().query().ok_or(QueryError::Missing)?; + let value = serde_urlencoded::from_str(query)?; + Ok(value) + })(); + Query(result) + } +} + +pub struct Json(Result); + +impl Json { + pub fn into_inner(self) -> Result { + self.0 + } +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum JsonError { + #[error("failed to consume the body")] + ConsumeBody(#[from] hyper::Error), + #[error("failed to deserialize the body")] + Deserialize(#[from] serde_json::Error), +} + +#[async_trait] +impl FromRequest for Json +where + T: DeserializeOwned, +{ + async fn from_request(req: &mut Request) -> Self { + // TODO(david): require the body to have `content-type: application/json` + + let body = std::mem::take(req.body_mut()); + + let result = async move { + let bytes = hyper::body::to_bytes(body).await?; + let value = serde_json::from_slice(&bytes)?; + Ok(value) + } + .await; + + Json(result) + } +} + +#[derive(Clone, Copy)] +pub struct EmptyRouter(()); + +impl Service> for EmptyRouter { + type Response = Response; + type Error = Error; + type Future = future::Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: Request) -> Self::Future { + let mut res = Response::new(Body::empty()); + *res.status_mut() = StatusCode::NOT_FOUND; + future::ready(Ok(res)) + } +} + +#[derive(Clone)] +pub struct Route { + handler: H, + route_spec: RouteSpec, + fallback: F, +} + +#[derive(Clone)] +struct RouteSpec { + method: Method, + spec: Bytes, +} + +impl RouteSpec { + fn matches(&self, req: &Request) -> bool { + // TODO(david): support dynamic placeholders like `/users/:id` + req.method() == self.method && req.uri().path().as_bytes() == self.spec + } +} + +impl Service> for Route +where + H: Service, Response = Response, Error = Error> + Clone + Send + 'static, + H::Future: Send, + F: Service, Response = Response, Error = Error> + Clone + Send + 'static, + F::Future: Send, +{ + type Response = Response; + type Error = Error; + type Future = future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // TODO(david): do we need to drive readiness in `call`? + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + if self.route_spec.matches(&req) { + let handler_clone = self.handler.clone(); + let mut handler = std::mem::replace(&mut self.handler, handler_clone); + Box::pin(async move { handler.ready().await?.call(req).await }) + } else { + let fallback_clone = self.fallback.clone(); + let mut fallback = std::mem::replace(&mut self.fallback, fallback_clone); + Box::pin(async move { fallback.ready().await?.call(req).await }) + } + } +} + +impl Service> for App +where + R: Service, Response = Response, Error = Error> + Clone, +{ + type Response = Response; + type Error = Error; + type Future = R::Future; + + // TODO(david): handle backpressure + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.router.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + self.router.call(req) + } +} + +#[cfg(test)] +mod tests { + #![allow(warnings)] + use super::*; + + #[tokio::test] + async fn basic() { + let mut app = app() + .at("/") + .get(root) + .at("/users") + .get(users_index) + .post(users_create) + .into_service(); + + let req = Request::builder() + .method(Method::POST) + .uri("/users") + .body(Body::from(r#"{ "username": "bob" }"#)) + .unwrap(); + + let res = app.ready().await.unwrap().call(req).await.unwrap(); + let body = body_to_string(res).await; + dbg!(&body); + } + + async fn body_to_string(res: Response) -> String { + let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() + } + + async fn root(req: Request) -> Result, Error> { + Ok(Response::new(Body::from("Hello, World!"))) + } + + async fn users_index( + req: Request, + pagination: Query, + ) -> Result, Error> { + dbg!(pagination.into_inner()); + Ok(Response::new(Body::from("users#index"))) + } + + #[derive(Debug, Deserialize)] + struct Pagination { + page: usize, + per_page: usize, + } + + async fn users_create( + req: Request, + payload: Json, + ) -> Result, Error> { + dbg!(payload.into_inner()); + Ok(Response::new(Body::from("users#create"))) + } + + #[derive(Debug, Deserialize)] + struct UsersCreate { + username: String, + } +}