diff --git a/Cargo.toml b/Cargo.toml index 6ea3c6a2..692abda5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,10 +19,11 @@ thiserror = "1.0" tower = { version = "0.4", features = ["util"] } [dev-dependencies] -tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } +hyper = { version = "0.14", features = ["full"] } +reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } tower = { version = "0.4", features = ["util", "make", "timeout"] } tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] } -hyper = { version = "0.14", features = ["full"] } tracing = "0.1" tracing-subscriber = "0.2" diff --git a/src/error.rs b/src/error.rs index dcfa50bc..9ad1e527 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,6 +44,9 @@ pub enum Error { #[error("unknown URL param `{0}`")] UnknownUrlParam(String), + + #[error("response body didn't contain valid UTF-8")] + InvalidUtf8, } impl From<Infallible> for Error { @@ -69,7 +72,8 @@ where Error::DeserializeRequestBody(_) | Error::QueryStringMissing | Error::DeserializeQueryString(_) - | Error::InvalidUrlParam { .. } => make_response(StatusCode::BAD_REQUEST), + | Error::InvalidUrlParam { .. } + | Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST), Error::Status(status) => make_response(status), diff --git a/src/extract.rs b/src/extract.rs index 4892016b..cf92d9fc 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,7 +1,7 @@ use crate::{body::Body, Error}; use bytes::Bytes; use futures_util::{future, ready}; -use http::Request; +use http::{header, Request, StatusCode}; use pin_project::pin_project; use serde::de::DeserializeOwned; use std::{ @@ -86,20 +86,39 @@ where type Future = future::BoxFuture<'static, Result<Self, Error>>; fn from_request(req: &mut Request<Body>) -> Self::Future { - // TODO(david): require the body to have `content-type: application/json` + if has_content_type(&req, "application/json") { + let body = std::mem::take(req.body_mut()); - let body = std::mem::take(req.body_mut()); - - Box::pin(async move { - let bytes = hyper::body::to_bytes(body) - .await - .map_err(Error::ConsumeRequestBody)?; - let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; - Ok(Json(value)) - }) + Box::pin(async move { + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)?; + let value = + serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; + Ok(Json(value)) + }) + } else { + Box::pin(async { Err(Error::Status(StatusCode::BAD_REQUEST)) }) + } } } +fn has_content_type<B>(req: &Request<B>, expected_content_type: &str) -> bool { + let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { + content_type + } else { + return false; + }; + + let content_type = if let Ok(content_type) = content_type.to_str() { + content_type + } else { + return false; + }; + + content_type.starts_with(expected_content_type) +} + #[derive(Debug, Clone, Copy)] pub struct Extension<T>(T); @@ -146,6 +165,32 @@ impl FromRequest for Bytes { } } +impl FromRequest for String { + type Future = future::BoxFuture<'static, Result<Self, Error>>; + + fn from_request(req: &mut Request<Body>) -> Self::Future { + let body = std::mem::take(req.body_mut()); + + Box::pin(async move { + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)? + .to_vec(); + let string = String::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)?; + Ok(string) + }) + } +} + +impl FromRequest for Body { + type Future = future::Ready<Result<Self, Error>>; + + fn from_request(req: &mut Request<Body>) -> Self::Future { + let body = std::mem::take(req.body_mut()); + future::ok(body) + } +} + #[derive(Debug, Clone)] pub struct BytesMaxLength<const N: u64>(Bytes); diff --git a/src/lib.rs b/src/lib.rs index ed3bda72..d8c3e46a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,3 @@ -/* - -Improvements to make: - -Support extracting headers, perhaps via `headers::Header`? - -Improve compile times with lots of routes, can we box and combine routers? - -Tests - -*/ - use self::{ body::Body, routing::{EmptyRouter, RouteAt}, @@ -33,6 +21,9 @@ pub mod routing; mod error; +#[cfg(test)] +mod tests; + pub use self::error::Error; pub fn app() -> App<EmptyRouter> { @@ -139,173 +130,3 @@ where } } } - -#[cfg(test)] -mod tests { - #![allow(warnings)] - use super::*; - use crate::handler::Handler; - use http::{Method, Request, StatusCode}; - use hyper::Server; - use serde::Deserialize; - use std::time::Duration; - use std::{fmt, net::SocketAddr, sync::Arc}; - use tower::{ - layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder, - ServiceExt, - }; - use tower_http::{ - add_extension::AddExtensionLayer, - compression::CompressionLayer, - trace::{Trace, TraceLayer}, - }; - - #[tokio::test] - async fn basic() { - #[derive(Debug, Deserialize)] - struct Pagination { - page: usize, - per_page: usize, - } - - #[derive(Debug, Deserialize)] - struct UsersCreate { - username: String, - } - - async fn root(_: Request<Body>) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::from("Hello, World!"))) - } - - async fn large_static_file( - _: Request<Body>, - body: extract::BytesMaxLength<{ 1024 * 500 }>, - ) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::empty())) - } - - let app = app() - // routes with functions - .at("/") - .get(root) - // routes with closures - .at("/users") - .get( - |_: Request<Body>, pagination: extract::Query<Pagination>| async { - let pagination = pagination.into_inner(); - assert_eq!(pagination.page, 1); - assert_eq!(pagination.per_page, 30); - Ok::<_, Error>("users#index".to_string()) - }, - ) - .post( - |_: Request<Body>, - payload: extract::Json<UsersCreate>, - _state: extract::Extension<Arc<State>>| async { - let payload = payload.into_inner(); - assert_eq!(payload.username, "bob"); - Ok::<_, Error>(response::Json( - serde_json::json!({ "username": payload.username }), - )) - }, - ) - // routes with a service - .at("/service") - .get_service(service_fn(root)) - // routes with layers applied - .at("/large-static-file") - .get( - large_static_file.layer( - ServiceBuilder::new() - .layer(TimeoutLayer::new(Duration::from_secs(30))) - .layer(CompressionLayer::new()) - .into_inner(), - ), - ) - .into_service(); - - // state shared by all routes, could hold db connection etc - struct State {} - - let state = Arc::new(State {}); - - // can add more middleware - let mut app = ServiceBuilder::new() - .layer(AddExtensionLayer::new(state)) - .layer(TraceLayer::new_for_http()) - .service(app); - - let res = app - .ready() - .await - .unwrap() - .call( - Request::builder() - .method(Method::GET) - .uri("/") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(body_to_string(res).await, "Hello, World!"); - - let res = app - .ready() - .await - .unwrap() - .call( - Request::builder() - .method(Method::GET) - .uri("/users?page=1&per_page=30") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(body_to_string(res).await, "users#index"); - - let res = app - .ready() - .await - .unwrap() - .call( - Request::builder() - .method(Method::GET) - .uri("/users") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(body_to_string(res).await, ""); - - let res = app - .ready() - .await - .unwrap() - .call( - Request::builder() - .method(Method::POST) - .uri("/users") - .body(Body::from(r#"{ "username": "bob" }"#)) - .unwrap(), - ) - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#); - } - - async fn body_to_string<B>(res: Response<B>) -> String - where - B: http_body::Body, - B::Error: fmt::Debug, - { - let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); - String::from_utf8(bytes.to_vec()).unwrap() - } -} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 00000000..a0924932 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,192 @@ +use crate::{app, extract, response}; +use http::{Request, Response, StatusCode}; +use hyper::{Body, Server}; +use serde::Deserialize; +use serde_json::json; +use std::net::{SocketAddr, TcpListener}; +use tower::{make::Shared, BoxError, Service}; + +#[tokio::test] +async fn hello_world() { + let app = app() + .at("/") + .get(|_: Request<Body>| async { Ok("Hello, World!") }) + .into_service(); + + let addr = run_in_background(app).await; + + let res = reqwest::get(format!("http://{}", addr)).await.unwrap(); + let body = res.text().await.unwrap(); + + assert_eq!(body, "Hello, World!"); +} + +#[tokio::test] +async fn consume_body() { + let app = app() + .at("/") + .get(|_: Request<Body>, body: String| async { Ok(body) }) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + let res = client + .get(format!("http://{}", addr)) + .body("foo") + .send() + .await + .unwrap(); + let body = res.text().await.unwrap(); + + assert_eq!(body, "foo"); +} + +#[tokio::test] +async fn deserialize_body() { + #[derive(Debug, Deserialize)] + struct Input { + foo: String, + } + + let app = app() + .at("/") + .post(|_: Request<Body>, input: extract::Json<Input>| async { Ok(input.into_inner().foo) }) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + let res = client + .post(format!("http://{}", addr)) + .json(&json!({ "foo": "bar" })) + .send() + .await + .unwrap(); + let body = res.text().await.unwrap(); + + assert_eq!(body, "bar"); +} + +#[tokio::test] +async fn consume_body_to_json_requires_json_content_type() { + #[derive(Debug, Deserialize)] + struct Input { + foo: String, + } + + let app = app() + .at("/") + .post(|_: Request<Body>, input: extract::Json<Input>| async { + let input = input.into_inner(); + Ok(input.foo) + }) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + let res = client + .post(format!("http://{}", addr)) + .body(r#"{ "foo": "bar" }"#) + .send() + .await + .unwrap(); + + // TODO(david): is this the most appropriate response code? + assert_eq!(res.status(), StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn body_with_length_limit() { + use std::iter::repeat; + + #[derive(Debug, Deserialize)] + struct Input { + foo: String, + } + + const LIMIT: u64 = 8; + + let app = app() + .at("/") + .post( + |req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move { + dbg!(&req); + Ok(response::Empty) + }, + ) + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .post(format!("http://{}", addr)) + .body(repeat(0_u8).take((LIMIT - 1) as usize).collect::<Vec<_>>()) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}", addr)) + .body(repeat(0_u8).take(LIMIT as usize).collect::<Vec<_>>()) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}", addr)) + .body(repeat(0_u8).take((LIMIT + 1) as usize).collect::<Vec<_>>()) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); + + let res = client + .post(format!("http://{}", addr)) + .body(reqwest::Body::wrap_stream(futures_util::stream::iter( + vec![Ok::<_, std::io::Error>(bytes::Bytes::new())], + ))) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); +} + +// TODO(david): can extractors change the request type? +// TODO(david): should FromRequest be an async-trait? + +// TODO(david): routing + +// TODO(david): lots of routes and boxing, shouldn't take forever to compile + +/// Run a `tower::Service` in the background and get a URI for it. +pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr +where + S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Send, + ResBody::Error: Into<BoxError>, + S::Error: Into<BoxError>, + S::Future: Send, +{ + let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket"); + let addr = listener.local_addr().unwrap(); + println!("Listening on {}", addr); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc)); + tx.send(()).unwrap(); + server.await.expect("server error"); + }); + + rx.await.unwrap(); + + addr +}