Generic response types

This commit is contained in:
David Pedersen 2021-05-30 12:26:58 +02:00
parent bf3d4f4a40
commit 433b3183c4

View file

@ -4,15 +4,10 @@
Improvements to make:
Somehow return generic "into response" kinda types without having to manually
create hyper::Body for everything
Break stuff up into modules
Support extracting headers, perhaps via `headers::Header`?
Body to bytes extractor
Implement `FromRequest` for more functions, with macro
Tests
*/
@ -20,7 +15,7 @@ Tests
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode};
use http::{header, HeaderValue, Method, Request, Response, StatusCode};
use http_body::Body as _;
use pin_project::pin_project;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
@ -69,9 +64,9 @@ pub struct RouteAt<R> {
}
impl<R> RouteAt<R> {
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where
F: Handler<T>,
F: Handler<B, T>,
{
self.add_route(handler_fn, Method::GET)
}
@ -84,9 +79,9 @@ impl<R> RouteAt<R> {
self.add_route_service(service, Method::GET)
}
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where
F: Handler<T>,
F: Handler<B, T>,
{
self.add_route(handler_fn, Method::POST)
}
@ -99,9 +94,13 @@ impl<R> RouteAt<R> {
self.add_route_service(service, Method::POST)
}
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
fn add_route<H, B, T>(
self,
handler: H,
method: Method,
) -> RouteBuilder<Route<HandlerSvc<H, B, T>, R>>
where
H: Handler<T>,
H: Handler<B, T>,
{
self.add_route_service(HandlerSvc::new(handler), method)
}
@ -149,9 +148,9 @@ impl<R> RouteBuilder<R> {
self.app.at(route_spec)
}
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
pub fn get<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where
F: Handler<T>,
F: Handler<B, T>,
{
self.app.at_bytes(self.route_spec).get(handler_fn)
}
@ -164,9 +163,9 @@ impl<R> RouteBuilder<R> {
self.app.at_bytes(self.route_spec).get_service(service)
}
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
pub fn post<F, B, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, B, T>, R>>
where
F: Handler<T>,
F: Handler<B, T>,
{
self.app.at_bytes(self.route_spec).post(handler_fn)
}
@ -193,6 +192,9 @@ pub enum Error {
#[error("failed to deserialize the request body")]
DeserializeRequestBody(#[source] serde_json::Error),
#[error("failed to serialize the response body")]
SerializeResponseBody(#[source] serde_json::Error),
#[error("failed to consume the body")]
ConsumeRequestBody(#[source] hyper::Error),
@ -225,8 +227,8 @@ mod sealed {
}
#[async_trait]
pub trait Handler<In>: Sized {
type Response: IntoResponse;
pub trait Handler<B, In>: Sized {
type Response: IntoResponse<B>;
// This seals the trait. We cannot use the regular "sealed super trait" approach
// due to coherence.
@ -237,37 +239,94 @@ pub trait Handler<In>: Sized {
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
where
L: Layer<HandlerSvc<Self, In>>,
L: Layer<HandlerSvc<Self, B, In>>,
{
Layered::new(layer.layer(HandlerSvc::new(self)))
}
}
pub trait IntoResponse {
fn into_response(self) -> Response<Body>;
pub trait IntoResponse<B> {
fn into_response(self) -> Result<Response<B>, Error>;
}
impl<B> IntoResponse for Response<B>
where
B: Into<Body>,
{
fn into_response(self) -> Response<Body> {
self.map(Into::into)
impl<B> IntoResponse<B> for Response<B> {
fn into_response(self) -> Result<Response<B>, Error> {
Ok(self)
}
}
impl IntoResponse for String {
fn into_response(self) -> Response<Body> {
Response::new(Body::from(self))
impl IntoResponse<Body> for &'static str {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for String {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Bytes {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for &'static [u8] {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for Vec<u8> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, str> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
impl IntoResponse<Body> for std::borrow::Cow<'static, [u8]> {
fn into_response(self) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::from(self)))
}
}
// TODO(david): rename this to Json when its in another module
pub struct JsonBody<T>(T);
impl<T> IntoResponse<Body> for JsonBody<T>
where
T: Serialize,
{
fn into_response(self) -> Result<Response<Body>, Error> {
let bytes = serde_json::to_vec(&self.0).map_err(Error::SerializeResponseBody)?;
let len = bytes.len();
let mut res = Response::new(Body::from(bytes));
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
res.headers_mut()
.insert(header::CONTENT_LENGTH, HeaderValue::from(len));
Ok(res)
}
}
#[async_trait]
impl<F, Fut, Res> Handler<()> for F
impl<F, Fut, B, Res> Handler<B, ()> for F
where
F: Fn(Request<Body>) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse,
Res: IntoResponse<B>,
{
type Response = Res;
@ -282,11 +341,11 @@ macro_rules! impl_handler {
( $head:ident $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, Res, $head> Handler<($head,)> for F
impl<F, Fut, B, Res, $head> Handler<B, ($head,)> for F
where
F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse,
Res: IntoResponse<B>,
$head: FromRequest + Send,
{
type Response = Res;
@ -304,11 +363,11 @@ macro_rules! impl_handler {
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, Res, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
impl<F, Fut, B, Res, $head, $($tail,)*> Handler<B, ($head, $($tail,)*)> for F
where
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
Fut: Future<Output = Result<Res, Error>> + Send,
Res: IntoResponse,
Res: IntoResponse<B>,
$head: FromRequest + Send,
$( $tail: FromRequest + Send, )*
{
@ -347,10 +406,9 @@ where
}
#[async_trait]
impl<S, T> Handler<T> for Layered<S, T>
impl<S, B, T> Handler<B, T> for Layered<S, T>
where
S: Service<Request<Body>> + Send,
S::Response: IntoResponse,
S: Service<Request<Body>, Response = Response<B>> + Send,
S::Error: Into<BoxError>,
S::Future: Send,
{
@ -375,12 +433,12 @@ impl<S, T> Layered<S, T> {
}
}
pub struct HandlerSvc<H, T> {
pub struct HandlerSvc<H, B, T> {
handler: H,
_input: PhantomData<fn() -> T>,
_input: PhantomData<fn() -> (B, T)>,
}
impl<H, T> HandlerSvc<H, T> {
impl<H, B, T> HandlerSvc<H, B, T> {
fn new(handler: H) -> Self {
Self {
handler,
@ -389,7 +447,7 @@ impl<H, T> HandlerSvc<H, T> {
}
}
impl<H, T> Clone for HandlerSvc<H, T>
impl<H, B, T> Clone for HandlerSvc<H, B, T>
where
H: Clone,
{
@ -401,25 +459,27 @@ where
}
}
impl<H, T> Service<Request<Body>> for HandlerSvc<H, T>
impl<H, B, T> Service<Request<Body>> for HandlerSvc<H, B, T>
where
H: Handler<T> + Clone + Send + 'static,
H: Handler<B, T> + Clone + Send + 'static,
H::Response: 'static,
{
type Response = Response<Body>;
type Response = Response<B>;
type Error = Error;
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// HandlerSvc can only be constructed from async functions which are always ready
// HandlerSvc can only be constructed from async functions which are always ready, or from
// `Layered` which bufferes in `<Layered as Handler>::call` and is therefore also always
// ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let handler = self.handler.clone();
Box::pin(async move {
let res = Handler::call(handler, req).await?;
Ok(res.into_response())
let res = Handler::call(handler, req).await?.into_response()?;
Ok(res)
})
}
}
@ -543,6 +603,31 @@ where
}
}
// TODO(david): rename this to Bytes when its in another module
#[derive(Debug, Clone)]
pub struct BytesBody(Bytes);
impl BytesBody {
pub fn into_inner(self) -> Bytes {
self.0
}
}
impl FromRequest for BytesBody {
type Future = future::BoxFuture<'static, Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
let body = std::mem::take(req.body_mut());
Box::pin(async move {
let bytes = hyper::body::to_bytes(body)
.await
.map_err(Error::ConsumeRequestBody)?;
Ok(BytesBody(bytes))
})
}
}
#[derive(Clone, Copy)]
pub struct EmptyRouter(());
@ -771,7 +856,9 @@ where
| Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::MissingExtension { .. } => make_response(StatusCode::INTERNAL_SERVER_ERROR),
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => {
make_response(StatusCode::INTERNAL_SERVER_ERROR)
}
Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err),
@ -820,46 +907,44 @@ mod tests {
Ok(Response::new(Body::empty()))
}
async fn users_index(
_: Request<Body>,
pagination: Query<Pagination>,
) -> Result<String, Error> {
let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
Ok::<_, Error>("users#index".to_string())
}
let app = app()
// routes with functions
.at("/")
.get(root)
// routes with closures
.at("/users")
.get(users_index)
.post(
|_: Request<Body>,
payload: Json<UsersCreate>,
_state: Extension<Arc<State>>| async {
let payload = payload.into_inner();
assert_eq!(payload.username, "bob");
Ok::<_, Error>(Response::new(Body::from("users#create")))
},
)
// routes with a service
.at("/service")
.get_service(service_fn(root))
// routes with layers applied
.at("/large-static-file")
.get(
large_static_file.layer(
ServiceBuilder::new()
.layer(TimeoutLayer::new(Duration::from_secs(30)))
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
let app =
app()
// routes with functions
.at("/")
.get(root)
// routes with closures
.at("/users")
.get(|_: Request<Body>, pagination: Query<Pagination>| async {
let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
Ok::<_, Error>("users#index".to_string())
})
.post(
|_: Request<Body>,
payload: Json<UsersCreate>,
_state: Extension<Arc<State>>| async {
let payload = payload.into_inner();
assert_eq!(payload.username, "bob");
Ok::<_, Error>(JsonBody(
serde_json::json!({ "username": payload.username }),
))
},
)
// routes with a service
.at("/service")
.get_service(service_fn(root))
// routes with layers applied
.at("/large-static-file")
.get(
large_static_file.layer(
ServiceBuilder::new()
.layer(TimeoutLayer::new(Duration::from_secs(30)))
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
// state shared by all routes, could hold db connection etc
struct State {}
@ -934,7 +1019,7 @@ mod tests {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "users#create");
assert_eq!(body_to_string(res).await, r#"{"username":"bob"}"#);
}
async fn body_to_string<B>(res: Response<B>) -> String