diff --git a/Cargo.toml b/Cargo.toml index d417747c..19f0caf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ serde_json = "1.0" serde_urlencoded = "0.7" thiserror = "1.0" tower = { version = "0.4", features = ["util", "buffer"] } +tower-http = { version = "0.1", features = ["add-extension"] } [dev-dependencies] hyper = { version = "0.14", features = ["full"] } diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 98c9c51d..382cae3d 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -51,10 +51,10 @@ async fn get( params: extract::UrlParams<(String,)>, state: extract::Extension, ) -> Result { - let state = state.into_inner(); + let state = state.0; let db = &state.lock().unwrap().db; - let key = params.into_inner(); + let (key,) = params.0; if let Some(value) = db.get(&key) { Ok(value.clone()) @@ -69,11 +69,11 @@ async fn set( value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb state: extract::Extension, ) { - let state = state.into_inner(); + let state = state.0; let db = &mut state.lock().unwrap().db; - let key = params.into_inner(); - let value = value.into_inner(); + let (key,) = params.0; + let value = value.0; db.insert(key, value); } diff --git a/src/extract.rs b/src/extract.rs index fd70e254..4ccbe6a6 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,10 +1,7 @@ -use crate::{ - body::Body, - response::{BoxIntoResponse, IntoResponse}, -}; +use crate::{body::Body, response::IntoResponse}; use async_trait::async_trait; use bytes::Bytes; -use http::{header, Request}; +use http::{header, Response, Request}; use serde::de::DeserializeOwned; use std::{collections::HashMap, convert::Infallible, str::FromStr}; @@ -80,13 +77,7 @@ define_rejection! { } #[derive(Debug, Clone, Copy)] -pub struct Query(T); - -impl Query { - pub fn into_inner(self) -> T { - self.0 - } -} +pub struct Query(pub T); #[async_trait] impl FromRequest for Query @@ -103,13 +94,7 @@ where } #[derive(Debug, Clone, Copy)] -pub struct Json(T); - -impl Json { - pub fn into_inner(self) -> T { - self.0 - } -} +pub struct Json(pub T); define_rejection! { #[status = BAD_REQUEST] @@ -128,24 +113,24 @@ impl FromRequest for Json where T: DeserializeOwned, { - type Rejection = BoxIntoResponse; + type Rejection = Response; async fn from_request(req: &mut Request) -> Result { if has_content_type(&req, "application/json") { - let body = take_body(req).map_err(IntoResponse::boxed)?; + let body = take_body(req).map_err(IntoResponse::into_response)?; let bytes = hyper::body::to_bytes(body) .await .map_err(InvalidJsonBody::from_err) - .map_err(IntoResponse::boxed)?; + .map_err(IntoResponse::into_response)?; let value = serde_json::from_slice(&bytes) .map_err(InvalidJsonBody::from_err) - .map_err(IntoResponse::boxed)?; + .map_err(IntoResponse::into_response)?; Ok(Json(value)) } else { - Err(MissingJsonContentType(()).boxed()) + Err(MissingJsonContentType(()).into_response()) } } } @@ -173,13 +158,7 @@ define_rejection! { } #[derive(Debug, Clone, Copy)] -pub struct Extension(T); - -impl Extension { - pub fn into_inner(self) -> T { - self.0 - } -} +pub struct Extension(pub T); #[async_trait] impl FromRequest for Extension @@ -207,15 +186,15 @@ define_rejection! { #[async_trait] impl FromRequest for Bytes { - type Rejection = BoxIntoResponse; + type Rejection = Response; async fn from_request(req: &mut Request) -> Result { - let body = take_body(req).map_err(IntoResponse::boxed)?; + let body = take_body(req).map_err(IntoResponse::into_response)?; let bytes = hyper::body::to_bytes(body) .await .map_err(FailedToBufferBody::from_err) - .map_err(IntoResponse::boxed)?; + .map_err(IntoResponse::into_response)?; Ok(bytes) } @@ -229,20 +208,20 @@ define_rejection! { #[async_trait] impl FromRequest for String { - type Rejection = BoxIntoResponse; + type Rejection = Response; async fn from_request(req: &mut Request) -> Result { - let body = take_body(req).map_err(IntoResponse::boxed)?; + let body = take_body(req).map_err(IntoResponse::into_response)?; let bytes = hyper::body::to_bytes(body) .await .map_err(FailedToBufferBody::from_err) - .map_err(IntoResponse::boxed)? + .map_err(IntoResponse::into_response)? .to_vec(); let string = String::from_utf8(bytes) .map_err(InvalidUtf8::from_err) - .map_err(IntoResponse::boxed)?; + .map_err(IntoResponse::into_response)?; Ok(string) } @@ -270,36 +249,30 @@ define_rejection! { } #[derive(Debug, Clone)] -pub struct BytesMaxLength(Bytes); - -impl BytesMaxLength { - pub fn into_inner(self) -> Bytes { - self.0 - } -} +pub struct BytesMaxLength(pub Bytes); #[async_trait] impl FromRequest for BytesMaxLength { - type Rejection = BoxIntoResponse; + type Rejection = Response; async fn from_request(req: &mut Request) -> Result { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); - let body = take_body(req).map_err(|reject| reject.boxed())?; + let body = take_body(req).map_err(|reject| reject.into_response())?; let content_length = content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); if let Some(length) = content_length { if length > N { - return Err(PayloadTooLarge(()).boxed()); + return Err(PayloadTooLarge(()).into_response()); } } else { - return Err(LengthRequired(()).boxed()); + return Err(LengthRequired(()).into_response()); }; let bytes = hyper::body::to_bytes(body) .await - .map_err(|e| FailedToBufferBody::from_err(e).boxed())?; + .map_err(|e| FailedToBufferBody::from_err(e).into_response())?; Ok(BytesMaxLength(bytes)) } @@ -367,7 +340,7 @@ impl IntoResponse for InvalidUrlParam { } } -pub struct UrlParams(T); +pub struct UrlParams(pub T); macro_rules! impl_parse_url { () => {}; @@ -379,7 +352,7 @@ macro_rules! impl_parse_url { $head: FromStr + Send, $( $tail: FromStr + Send, )* { - type Rejection = BoxIntoResponse; + type Rejection = Response; #[allow(non_snake_case)] async fn from_request(req: &mut Request) -> Result { @@ -389,27 +362,27 @@ macro_rules! impl_parse_url { { params.take().expect("params already taken").0 } else { - return Err(MissingRouteParams(()).boxed()) + return Err(MissingRouteParams(()).into_response()) }; if let [(_, $head), $((_, $tail),)*] = &*params { let $head = if let Ok(x) = $head.parse::<$head>() { x } else { - return Err(InvalidUrlParam::new::<$head>().boxed()); + return Err(InvalidUrlParam::new::<$head>().into_response()); }; $( let $tail = if let Ok(x) = $tail.parse::<$tail>() { x } else { - return Err(InvalidUrlParam::new::<$tail>().boxed()); + return Err(InvalidUrlParam::new::<$tail>().into_response()); }; )* Ok(UrlParams(($head, $($tail,)*))) } else { - return Err(MissingRouteParams(()).boxed()) + return Err(MissingRouteParams(()).into_response()) } } } @@ -420,49 +393,6 @@ macro_rules! impl_parse_url { impl_parse_url!(T1, T2, T3, T4, T5, T6); -impl UrlParams<(T1,)> { - pub fn into_inner(self) -> T1 { - (self.0).0 - } -} - -impl UrlParams<(T1, T2)> { - pub fn into_inner(self) -> (T1, T2) { - ((self.0).0, (self.0).1) - } -} - -impl UrlParams<(T1, T2, T3)> { - pub fn into_inner(self) -> (T1, T2, T3) { - ((self.0).0, (self.0).1, (self.0).2) - } -} - -impl UrlParams<(T1, T2, T3, T4)> { - pub fn into_inner(self) -> (T1, T2, T3, T4) { - ((self.0).0, (self.0).1, (self.0).2, (self.0).3) - } -} - -impl UrlParams<(T1, T2, T3, T4, T5)> { - pub fn into_inner(self) -> (T1, T2, T3, T4, T5) { - ((self.0).0, (self.0).1, (self.0).2, (self.0).3, (self.0).4) - } -} - -impl UrlParams<(T1, T2, T3, T4, T5, T6)> { - pub fn into_inner(self) -> (T1, T2, T3, T4, T5, T6) { - ( - (self.0).0, - (self.0).1, - (self.0).2, - (self.0).3, - (self.0).4, - (self.0).5, - ) - } -} - define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "Cannot have two request body extractors for a single handler"] diff --git a/src/lib.rs b/src/lib.rs index 5658d43f..47069f98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,9 @@ pub mod handler; pub mod response; pub mod routing; +pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; +pub use async_trait::async_trait; + #[cfg(test)] mod tests; diff --git a/src/response.rs b/src/response.rs index 0b2b0a89..8e722eee 100644 --- a/src/response.rs +++ b/src/response.rs @@ -5,17 +5,9 @@ use serde::Serialize; use std::convert::Infallible; use tower::util::Either; +// TODO(david): can we change this to not be generic over the body and just use hyper::Body? pub trait IntoResponse { fn into_response(self) -> Response; - - // TODO(david): remove this an return return `Response` instead. That is what this method - // does anyway. - fn boxed(self) -> BoxIntoResponse - where - Self: Sized + 'static, - { - BoxIntoResponse(self.into_response()) - } } impl IntoResponse for () @@ -171,14 +163,6 @@ where } } -pub struct BoxIntoResponse(Response); - -impl IntoResponse for BoxIntoResponse { - fn into_response(self) -> Response { - self.0 - } -} - impl IntoResponse for StatusCode where B: Default, diff --git a/src/tests.rs b/src/tests.rs index 92977bd1..183f6881 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -51,7 +51,7 @@ async fn deserialize_body() { let app = app() .at("/") - .post(|_: Request, input: extract::Json| async { input.into_inner().foo }) + .post(|_: Request, input: extract::Json| async { input.0.foo }) .into_service(); let addr = run_in_background(app).await; @@ -78,8 +78,7 @@ async fn consume_body_to_json_requires_json_content_type() { let app = app() .at("/") .post(|_: Request, input: extract::Json| async { - let input = input.into_inner(); - input.foo + input.0.foo }) .into_service(); @@ -216,7 +215,7 @@ async fn extracting_url_params() { .at("/users/:id") .get( |_: Request, params: extract::UrlParams<(i32,)>| async move { - let id = params.into_inner(); + let (id,) = params.0; assert_eq!(id, 42); }, )