Generic response types

This commit is contained in:
David Pedersen 2021-05-30 12:26:58 +02:00
parent bf3d4f4a40
commit 433b3183c4

View file

@ -4,15 +4,10 @@
Improvements to make: Improvements to make:
Somehow return generic "into response" kinda types without having to manually Break stuff up into modules
create hyper::Body for everything
Support extracting headers, perhaps via `headers::Header`? Support extracting headers, perhaps via `headers::Header`?
Body to bytes extractor
Implement `FromRequest` for more functions, with macro
Tests Tests
*/ */
@ -20,7 +15,7 @@ Tests
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future, ready}; use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode}; use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body::Body as _; use http_body::Body as _;
use pin_project::pin_project; use pin_project::pin_project;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -69,9 +64,9 @@ pub struct RouteAt<R> {
} }
impl<R> RouteAt<R> { impl<R> RouteAt<R> {
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where where
F: Handler<T>, F: Handler<B, T>,
{ {
self.add_route(handler_fn, Method::GET) self.add_route(handler_fn, Method::GET)
} }
@ -84,9 +79,9 @@ impl<R> RouteAt<R> {
self.add_route_service(service, Method::GET) self.add_route_service(service, Method::GET)
} }
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where where
F: Handler<T>, F: Handler<B, T>,
{ {
self.add_route(handler_fn, Method::POST) self.add_route(handler_fn, Method::POST)
} }
@ -99,9 +94,13 @@ impl<R> RouteAt<R> {
self.add_route_service(service, Method::POST) self.add_route_service(service, Method::POST)
} }
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>> fn add_route<H, B, T>(
self,
handler: H,
method: Method,
) -> RouteBuilder<Route<HandlerSvc<H, B, T>, R>>
where where
H: Handler<T>, H: Handler<B, T>,
{ {
self.add_route_service(HandlerSvc::new(handler), method) self.add_route_service(HandlerSvc::new(handler), method)
} }
@ -149,9 +148,9 @@ impl<R> RouteBuilder<R> {
self.app.at(route_spec) self.app.at(route_spec)
} }
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where where
F: Handler<T>, F: Handler<B, T>,
{ {
self.app.at_bytes(self.route_spec).get(handler_fn) self.app.at_bytes(self.route_spec).get(handler_fn)
} }
@ -164,9 +163,9 @@ impl<R> RouteBuilder<R> {
self.app.at_bytes(self.route_spec).get_service(service) self.app.at_bytes(self.route_spec).get_service(service)
} }
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where where
F: Handler<T>, F: Handler<B, T>,
{ {
self.app.at_bytes(self.route_spec).post(handler_fn) self.app.at_bytes(self.route_spec).post(handler_fn)
} }
@ -193,6 +192,9 @@ pub enum Error {
#[error("failed to deserialize the request body")] #[error("failed to deserialize the request body")]
DeserializeRequestBody(#[source] serde_json::Error), DeserializeRequestBody(#[source] serde_json::Error),
#[error("failed to serialize the response body")]
SerializeResponseBody(#[source] serde_json::Error),
#[error("failed to consume the body")] #[error("failed to consume the body")]
ConsumeRequestBody(#[source] hyper::Error), ConsumeRequestBody(#[source] hyper::Error),
@ -225,8 +227,8 @@ mod sealed {
} }
#[async_trait] #[async_trait]
pub trait Handler<In>: Sized { pub trait Handler<B, In>: Sized {
type Response: IntoResponse; type Response: IntoResponse<B>;
// This seals the trait. We cannot use the regular "sealed super trait" approach // This seals the trait. We cannot use the regular "sealed super trait" approach
// due to coherence. // due to coherence.
@ -237,37 +239,94 @@ pub trait Handler<In>: Sized {
fn layer<L>(self, layer: L) -> Layered<L::Service, In> fn layer<L>(self, layer: L) -> Layered<L::Service, In>
where where
L: Layer<HandlerSvc<Self, In>>, L: Layer<HandlerSvc<Self, B, In>>,
{ {
Layered::new(layer.layer(HandlerSvc::new(self))) Layered::new(layer.layer(HandlerSvc::new(self)))
} }
} }
pub trait IntoResponse { pub trait IntoResponse<B> {
fn into_response(self) -> Response<Body>; fn into_response(self) -> Result<Response<B>, Error>;
} }
impl<B> IntoResponse for Response<B> impl<B> IntoResponse<B> for Response<B> {
fn into_response(self) -> Result<Response<B>, Error> {
Ok(self)
}
}
impl IntoResponse<Body> for &'static str {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for String {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Bytes {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for &'static [u8] {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Vec<u8> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
// TODO(david): rename this to Json when its in another module
pub struct JsonBody<T>(T);
impl<T> IntoResponse<Body> for JsonBody<T>
where where
B: Into<Body>, T: Serialize,
{ {
fn into_response(self) -> Response<Body> { fn into_response(self) -> Result<Response<Body>, Error> {
self.map(Into::into) let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
} let len = bytes.len();
} let mut res = Response::new(Body::from(bytes));
impl IntoResponse for String { res.headers_mut().insert(
fn into_response(self) -> Response<Body> { header::CONTENT_TYPE,
Response::new(Body::from(self)) HeaderValue::from_static("application/json"),
);
res.headers_mut()
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
Ok(res)
} }
} }
#[async_trait] #[async_trait]
impl<F, Fut, Res> Handler<()> for F impl<F, Fut, B, Res> Handler<B, ()> for F
where where
F: Fn(Request<Body>) -> Fut + Send + Sync, F: Fn(Request<Body>) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send, Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse, Res: IntoResponse<B>,
{ {
type Response = Res; type Response = Res;
@ -282,11 +341,11 @@ macro_rules! impl_handler {
( $head:ident $(,)? ) => { ( $head:ident $(,)? ) => {
#[async_trait] #[async_trait]
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<F, Fut, Res, $head> Handler<($head,)> for F impl<F, Fut, B, Res, $head> Handler<B, ($head,)> for F
where where
F: Fn(Request<Body>, $head) -> Fut + Send + Sync, F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send, Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse, Res: IntoResponse<B>,
$head: FromRequest + Send, $head: FromRequest + Send,
{ {
type Response = Res; type Response = Res;
@ -304,11 +363,11 @@ macro_rules! impl_handler {
( $head:ident, $($tail:ident),* $(,)? ) => { ( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait] #[async_trait]
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<F, Fut, Res, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F impl<F, Fut, B, Res, $head, $($tail,)*> Handler<B, ($head, $($tail,)*)> for F
where where
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync, F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send, Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse, Res: IntoResponse<B>,
$head: FromRequest + Send, $head: FromRequest + Send,
$( $tail: FromRequest + Send, )* $( $tail: FromRequest + Send, )*
{ {
@ -347,10 +406,9 @@ where
} }
#[async_trait] #[async_trait]
impl<S, T> Handler<T> for Layered<S, T> impl<S, B, T> Handler<B, T> for Layered<S, T>
where where
S: Service<Request<Body>> + Send, S: Service<Request<Body>, Response = Response<B>> + Send,
S::Response: IntoResponse,
S::Error: Into<BoxError>, S::Error: Into<BoxError>,
S::Future: Send, S::Future: Send,
{ {
@ -375,12 +433,12 @@ impl<S, T> Layered<S, T> {
} }
} }
pub struct HandlerSvc<H, T> { pub struct HandlerSvc<H, B, T> {
handler: H, handler: H,
_input: PhantomData<fn() -> T>, _input: PhantomData<fn() -> (B, T)>,
} }
impl<H, T> HandlerSvc<H, T> { impl<H, B, T> HandlerSvc<H, B, T> {
fn new(handler: H) -> Self { fn new(handler: H) -> Self {
Self { Self {
handler, handler,
@ -389,7 +447,7 @@ impl<H, T> HandlerSvc<H, T> {
} }
} }
impl<H, T> Clone for HandlerSvc<H, T> impl<H, B, T> Clone for HandlerSvc<H, B, T>
where where
H: Clone, H: Clone,
{ {
@ -401,25 +459,27 @@ where
} }
} }
impl<H, T> Service<Request<Body>> for HandlerSvc<H, T> impl<H, B, T> Service<Request<Body>> for HandlerSvc<H, B, T>
where where
H: Handler<T> + Clone + Send + 'static, H: Handler<B, T> + Clone + Send + 'static,
H::Response: 'static, H::Response: 'static,
{ {
type Response = Response<Body>; type Response = Response<B>;
type Error = Error; type Error = Error;
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// HandlerSvc can only be constructed from async functions which are always ready // HandlerSvc can only be constructed from async functions which are always ready, or from
// `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always
// ready.
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Request<Body>) -> Self::Future { fn call(&mut self, req: Request<Body>) -> Self::Future {
let handler = self.handler.clone(); let handler = self.handler.clone();
Box::pin(async move { Box::pin(async move {
let res = Handler::call(handler, req).await?; let res = Handler::call(handler, req).await?.into_response()?;
Ok(res.into_response()) Ok(res)
}) })
} }
} }
@ -543,6 +603,31 @@ where
} }
} }
// TODO(david): rename this to Bytes when its in another module
#[derive(Debug, Clone)]
pub struct BytesBody(Bytes);
impl BytesBody {
pub fn into_inner(self) -> Bytes {
self.0
}
}
impl FromRequest for BytesBody {
type Future = future::BoxFuture<'static, Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let body = std::mem::take(req.body_mut());
Box::pin(async move {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
Ok(BytesBody(bytes))
})
}
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct EmptyRouter(()); pub struct EmptyRouter(());
@ -771,7 +856,9 @@ where
| Error::QueryStringMissing | Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST), | Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR), Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
make_response(StatusCode::INTERNAL_SERVER_ERROR)
}
Error::Service(err) => match err.downcast::<Error>() { Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err), Ok(err) => Err(*err),
@ -820,30 +907,28 @@ mod tests {
Ok(Response::new(Body::empty())) Ok(Response::new(Body::empty()))
} }
async fn users_index( let app =
_: Request<Body>, app()
pagination: Query<Pagination>,
) -> Result<String, Error> {
let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
Ok::<_, Error>("users#index".to_string())
}
let app = app()
// routes with functions // routes with functions
.at("/") .at("/")
.get(root) .get(root)
// routes with closures // routes with closures
.at("/users") .at("/users")
.get(users_index) .get(|_: Request<Body>, pagination: Query<Pagination>| async {
let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
Ok::<_, Error>("users#index".to_string())
})
.post( .post(
|_: Request<Body>, |_: Request<Body>,
payload: Json<UsersCreate>, payload: Json<UsersCreate>,
_state: Extension<Arc<State>>| async { _state: Extension<Arc<State>>| async {
let payload = payload.into_inner(); let payload = payload.into_inner();
assert_eq!(payload.username, "bob"); assert_eq!(payload.username, "bob");
Ok::<_, Error>(Response::new(Body::from("users#create"))) Ok::<_, Error>(JsonBody(
serde_json::json!({ "username": payload.username }),
))
}, },
) )
// routes with a service // routes with a service
@ -934,7 +1019,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "users#create"); assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
} }
async fn body_to_string<B>(res: Response<B>) -> String async fn body_to_string<B>(res: Response<B>) -> String