diff --git a/Cargo.toml b/Cargo.toml index fcb3f190..6ea3c6a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,10 @@ thiserror = "1.0" tower = { version = "0.4", features = ["util"] } [dev-dependencies] -tokio = { version = "1.6.1", features = ["macros", "rt"] } +tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } serde = { version = "1.0", features = ["derive"] } tower = { version = "0.4", features = ["util", "make", "timeout"] } tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] } hyper = { version = "0.14", features = ["full"] } +tracing = "0.1" +tracing-subscriber = "0.2" diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs new file mode 100644 index 00000000..686dfef7 --- /dev/null +++ b/examples/key_value_store.rs @@ -0,0 +1,89 @@ +use bytes::Bytes; +use http::{Request, StatusCode}; +use hyper::Server; +use serde::Deserialize; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, + time::Duration, +}; +use tower::{make::Shared, ServiceBuilder}; +use tower_http::{ + add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, +}; +use tower_web::{body::Body, extract, response, Error}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // build our application with some routes + let app = tower_web::app() + .at("/get") + .get(get) + .at("/set") + .post(set) + // convert it into a `Service` + .into_service(); + + // add some middleware + let app = ServiceBuilder::new() + .timeout(Duration::from_secs(10)) + .layer(TraceLayer::new_for_http()) + .layer(CompressionLayer::new()) + .layer(AddExtensionLayer::new(SharedState::default())) + .service(app); + + // run it with hyper + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + let server = Server::bind(&addr).serve(Shared::new(app)); + server.await.unwrap(); +} + +type SharedState = Arc<Mutex<State>>; + +#[derive(Default)] +struct State { + db: HashMap<String, Bytes>, +} + +#[derive(Deserialize)] +struct GetSetQueryString { + key: String, +} + +async fn get( + _req: Request<Body>, + query: extract::Query<GetSetQueryString>, + state: extract::Extension<SharedState>, +) -> Result<Bytes, Error> { + let state = state.into_inner(); + let db = &state.lock().unwrap().db; + + let key = query.into_inner().key; + + if let Some(value) = db.get(&key) { + Ok(value.clone()) + } else { + Err(Error::WithStatus(StatusCode::NOT_FOUND)) + } +} + +async fn set( + _req: Request<Body>, + query: extract::Query<GetSetQueryString>, + value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb + state: extract::Extension<SharedState>, +) -> Result<response::Empty, Error> { + let state = state.into_inner(); + let db = &mut state.lock().unwrap().db; + + let key = query.into_inner().key; + let value = value.into_inner(); + + db.insert(key, value); + + Ok(response::Empty) +} diff --git a/src/error.rs b/src/error.rs index 0b8088b7..8edbde29 100644 --- a/src/error.rs +++ b/src/error.rs @@ -29,6 +29,15 @@ pub enum Error { #[error("request extension of type `{type_name}` was not set")] MissingExtension { type_name: &'static str }, + + #[error("`Content-Length` header is missing but was required")] + LengthRequired, + + #[error("response body was too large")] + PayloadTooLarge, + + #[error("response failed with status {0}")] + WithStatus(StatusCode), } impl From<Infallible> for Error { @@ -55,6 +64,11 @@ where | Error::QueryStringMissing | Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST), + Error::WithStatus(status) => make_response(status), + + Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED), + Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE), + Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => { make_response(StatusCode::INTERNAL_SERVER_ERROR) } diff --git a/src/extract.rs b/src/extract.rs index acfccf6f..dc42ccb3 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,6 +1,8 @@ use crate::{body::Body, Error}; +use bytes::Bytes; use futures_util::{future, ready}; use http::Request; +use http_body::Body as _; use pin_project::pin_project; use serde::de::DeserializeOwned; use std::{ @@ -128,16 +130,6 @@ where } } -// TODO(david): can we add a length limit somehow? Maybe a const generic? -#[derive(Debug, Clone)] -pub struct Bytes(bytes::Bytes); - -impl Bytes { - pub fn into_inner(self) -> bytes::Bytes { - self.0 - } -} - impl FromRequest for Bytes { type Future = future::BoxFuture<'static, Result<Self, Error>>; @@ -148,7 +140,44 @@ impl FromRequest for Bytes { let bytes = hyper::body::to_bytes(body) .await .map_err(Error::ConsumeRequestBody)?; - Ok(Bytes(bytes)) + Ok(bytes) + }) + } +} + +#[derive(Debug, Clone)] +pub struct BytesMaxLength<const N: u64>(Bytes); + +impl<const N: u64> BytesMaxLength<N> { + pub fn into_inner(self) -> Bytes { + self.0 + } +} + +impl<const N: u64> FromRequest for BytesMaxLength<N> { + type Future = future::BoxFuture<'static, Result<Self, Error>>; + + fn from_request(req: &mut Request<Body>) -> Self::Future { + let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); + let body = std::mem::take(req.body_mut()); + + Box::pin(async move { + let content_length = + content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); + + if let Some(length) = content_length { + if length > N { + return Err(Error::PayloadTooLarge); + } + } else { + return Err(Error::LengthRequired); + }; + + let bytes = hyper::body::to_bytes(body) + .await + .map_err(Error::ConsumeRequestBody)?; + + Ok(BytesMaxLength(bytes)) }) } } diff --git a/src/lib.rs b/src/lib.rs index ad3068e0..8ec5006a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,10 @@ Improvements to make: Support extracting headers, perhaps via `headers::Header`? +Actual routing + +Improve compile times with lots of routes, can we box and combine routers? + Tests */ @@ -181,7 +185,10 @@ mod tests { Ok(Response::new(Body::from("Hello, World!"))) } - async fn large_static_file(_: Request<Body>) -> Result<Response<Body>, Error> { + async fn large_static_file( + _: Request<Body>, + body: extract::BytesMaxLength<{ 1024 * 500 }>, + ) -> Result<Response<Body>, Error> { Ok(Response::new(Body::empty())) } @@ -309,24 +316,4 @@ mod tests { let bytes = hyper::body::to_bytes(res.into_body()).await.unwrap(); String::from_utf8(bytes.to_vec()).unwrap() } - - #[allow(dead_code)] - // this should just compile - async fn compatible_with_hyper_and_tower_http() { - let app = app() - .at("/") - .get(|_: Request<Body>| async { - Ok::<_, Error>(Response::new(Body::from("Hello, World!"))) - }) - .into_service(); - - let app = ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .layer(CompressionLayer::new()) - .service(app); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let server = Server::bind(&addr).serve(Shared::new(app)); - server.await.unwrap(); - } } diff --git a/src/response.rs b/src/response.rs index f41b1cb9..5174b741 100644 --- a/src/response.rs +++ b/src/response.rs @@ -25,33 +25,53 @@ impl IntoResponse<Body> for String { } } -impl IntoResponse<Body> for Bytes { - fn into_response(self) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::from(self))) - } -} - -impl IntoResponse<Body> for &'static [u8] { - fn into_response(self) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::from(self))) - } -} - -impl IntoResponse<Body> for Vec<u8> { - fn into_response(self) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::from(self))) - } -} - impl IntoResponse<Body> for std::borrow::Cow<'static, str> { fn into_response(self) -> Result<Response<Body>, Error> { Ok(Response::new(Body::from(self))) } } +impl IntoResponse<Body> for Bytes { + fn into_response(self) -> Result<Response<Body>, Error> { + let mut res = Response::new(Body::from(self)); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Ok(res) + } +} + +impl IntoResponse<Body> for &'static [u8] { + fn into_response(self) -> Result<Response<Body>, Error> { + let mut res = Response::new(Body::from(self)); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Ok(res) + } +} + +impl IntoResponse<Body> for Vec<u8> { + fn into_response(self) -> Result<Response<Body>, Error> { + let mut res = Response::new(Body::from(self)); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Ok(res) + } +} + impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> { fn into_response(self) -> Result<Response<Body>, Error> { - Ok(Response::new(Body::from(self))) + let mut res = Response::new(Body::from(self)); + res.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/octet-stream"), + ); + Ok(res) } } @@ -63,17 +83,20 @@ where { fn into_response(self) -> Result<Response<Body>, Error> { let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?; - let len = bytes.len(); let mut res = Response::new(Body::from(bytes)); - res.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); - - res.headers_mut() - .insert(header::CONTENT_LENGTH, HeaderValue::from(len)); - Ok(res) } } + +#[derive(Debug, Copy, Clone)] +pub struct Empty; + +impl IntoResponse<Body> for Empty { + fn into_response(self) -> Result<Response<Body>, Error> { + Ok(Response::new(Body::empty())) + } +}