Fix layer support

This commit is contained in:
David Pedersen 2021-05-30 04:28:24 +02:00
parent f983b37fea
commit b763eaa037
3 changed files with 244 additions and 66 deletions

View file

@ -22,5 +22,5 @@ tower = { version = "0.4", features = ["util"] }
tokio = { version = "1.6.1", features = ["macros", "rt"] } tokio = { version = "1.6.1", features = ["macros", "rt"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] } tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace"] } tower-http = { version = "0.1", features = ["trace", "compression"] }
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }

71
src/body.rs Normal file
View file

@ -0,0 +1,71 @@
use bytes::Buf;
use http_body::{Body, Empty};
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
/// A boxed [`Body`] trait object.
pub struct BoxBody<D, E> {
inner: Pin<Box<dyn Body<Data = D, Error = E> + Send + Sync + 'static>>,
}
impl<D, E> BoxBody<D, E> {
/// Create a new `BoxBody`.
pub fn new<B>(body: B) -> Self
where
B: Body<Data = D, Error = E> + Send + Sync + 'static,
D: Buf,
{
Self {
inner: Box::pin(body),
}
}
}
// TODO(david): upstream this to http-body?
impl<D, E> Default for BoxBody<D, E>
where
D: bytes::Buf + 'static,
{
fn default() -> Self {
BoxBody::new(Empty::<D>::new().map_err(|err| match err {}))
}
}
impl<D, E> fmt::Debug for BoxBody<D, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxBody").finish()
}
}
impl<D, E> Body for BoxBody<D, E>
where
D: Buf,
{
type Data = D;
type Error = E;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.inner.as_mut().poll_data(cx)
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
self.inner.as_mut().poll_trailers(cx)
}
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}
}

View file

@ -7,18 +7,6 @@ Improvements to make:
Somehow return generic "into response" kinda types without having to manually Somehow return generic "into response" kinda types without having to manually
create hyper::Body for everything create hyper::Body for everything
Don't make Query and Json contain a Result, instead make generic wrapper
for "optional" inputs
Make it possible to convert QueryError and JsonError into responses
Support wrapping single routes in tower::Layer
Support putting a tower::Service at a Route
Don't require the response body to be hyper::Body, wont work if we're wrapping
single routes in layers
Support extracting headers, perhaps via `headers::Header`? Support extracting headers, perhaps via `headers::Header`?
Implement `FromRequest` for more functions, with macro Implement `FromRequest` for more functions, with macro
@ -31,7 +19,7 @@ use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures_util::{future, ready}; use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode}; use http::{Method, Request, Response, StatusCode};
use http_body::{combinators::BoxBody, Body as _}; use http_body::Body as _;
use pin_project::pin_project; use pin_project::pin_project;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{ use std::{
@ -43,6 +31,9 @@ use std::{
}; };
use tower::{BoxError, Layer, Service, ServiceExt}; use tower::{BoxError, Layer, Service, ServiceExt};
mod body;
pub use body::BoxBody;
pub use hyper::body::Body; pub use hyper::body::Body;
pub fn app() -> App<EmptyRouter> { pub fn app() -> App<EmptyRouter> {
@ -134,12 +125,23 @@ impl<R> RouteAt<R> {
} }
} }
#[derive(Clone)]
pub struct RouteBuilder<R> { pub struct RouteBuilder<R> {
app: App<R>, app: App<R>,
route_spec: Bytes, route_spec: Bytes,
} }
impl<R> Clone for RouteBuilder<R>
where
R: Clone,
{
fn clone(&self) -> Self {
Self {
app: self.app.clone(),
route_spec: self.route_spec.clone(),
}
}
}
impl<R> RouteBuilder<R> { impl<R> RouteBuilder<R> {
pub fn at(self, route_spec: &str) -> RouteAt<R> { pub fn at(self, route_spec: &str) -> RouteAt<R> {
self.app.at(route_spec) self.app.at(route_spec)
@ -174,6 +176,13 @@ impl<R> RouteBuilder<R> {
{ {
self.app.at_bytes(self.route_spec).post_service(service) self.app.at_bytes(self.route_spec).post_service(service)
} }
pub fn into_service(self) -> IntoService<R> {
IntoService {
app: self.app,
poll_ready_error: None,
}
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -183,13 +192,13 @@ pub enum Error {
DeserializeRequestBody(#[source] serde_json::Error), DeserializeRequestBody(#[source] serde_json::Error),
#[error("failed to consume the body")] #[error("failed to consume the body")]
ConsumeBody(#[source] hyper::Error), ConsumeRequestBody(#[source] hyper::Error),
#[error("URI contained no query string")] #[error("URI contained no query string")]
QueryStringMissing, QueryStringMissing,
#[error("failed to deserialize query string")] #[error("failed to deserialize query string")]
DeserializeQueryString(#[from] serde_urlencoded::de::Error), DeserializeQueryString(#[source] serde_urlencoded::de::Error),
#[error("failed generating the response body")] #[error("failed generating the response body")]
ResponseBody(#[source] BoxError), ResponseBody(#[source] BoxError),
@ -198,15 +207,6 @@ pub enum Error {
Service(#[source] BoxError), Service(#[source] BoxError),
} }
impl From<BoxError> for Error {
fn from(err: BoxError) -> Self {
match err.downcast::<Error>() {
Ok(err) => *err,
Err(err) => Error::Service(err),
}
}
}
impl From<Infallible> for Error { impl From<Infallible> for Error {
fn from(err: Infallible) -> Self { fn from(err: Infallible) -> Self {
match err {} match err {}
@ -316,7 +316,7 @@ where
self.svc self.svc
.oneshot(req) .oneshot(req)
.await .await
.map_err(|err| Error::from(err.into())) .map_err(|err| Error::Service(err.into()))
} }
} }
@ -425,7 +425,7 @@ where
fn from_request(req: &mut Request<Body>) -> Self::Future { fn from_request(req: &mut Request<Body>) -> Self::Future {
let result = (|| { let result = (|| {
let query = req.uri().query().ok_or(Error::QueryStringMissing)?; let query = req.uri().query().ok_or(Error::QueryStringMissing)?;
let value = serde_urlencoded::from_str(query)?; let value = serde_urlencoded::from_str(query).map_err(Error::DeserializeQueryString)?;
Ok(Query(value)) Ok(Query(value))
})(); })();
@ -456,7 +456,7 @@ where
Box::pin(async move { Box::pin(async move {
let bytes = hyper::body::to_bytes(body) let bytes = hyper::body::to_bytes(body)
.await .await
.map_err(Error::ConsumeBody)?; .map_err(Error::ConsumeRequestBody)?;
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?; let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
Ok(Json(value)) Ok(Json(value))
}) })
@ -588,53 +588,116 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?; let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?;
let response = let response = response.map(|body| {
response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed()); let body = body.map_err(|err| Error::ResponseBody(err.into()));
BoxBody::new(body)
});
Poll::Ready(Ok(response)) Poll::Ready(Ok(response))
} }
} }
impl<R, T> Service<T> for App<R> pub struct IntoService<R> {
app: App<R>,
poll_ready_error: Option<Error>,
}
impl<R> Clone for IntoService<R>
where where
R: Service<T>, R: Clone,
R::Error: Into<Error>,
{ {
type Response = R::Response; fn clone(&self) -> Self {
type Error = R::Error; Self {
type Future = R::Future; app: self.app.clone(),
poll_ready_error: None,
#[inline] }
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): map error to response
self.router.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: T) -> Self::Future {
// TODO(david): map error to response
self.router.call(req)
} }
} }
impl<R, T> Service<T> for RouteBuilder<R> impl<R, B, T> Service<T> for IntoService<R>
where where
App<R>: Service<T>, R: Service<T, Response = Response<B>>,
<App<R> as Service<T>>::Error: Into<Error>, R::Error: Into<Error>,
B: Default,
{ {
type Response = <App<R> as Service<T>>::Response; type Response = Response<B>;
type Error = <App<R> as Service<T>>::Error; type Error = Error;
type Future = <App<R> as Service<T>>::Future; type Future = HandleErrorFuture<R::Future, B>;
#[inline] #[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): map error to response self.app.router.poll_ready(cx).map_err(Into::into)
self.app.poll_ready(cx)
} }
#[inline]
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: T) -> Self::Future {
// TODO(david): map error to response if let Some(poll_ready_error) = self.poll_ready_error.take() {
self.app.call(req) match handle_error::<B>(poll_ready_error) {
Ok(res) => {
return HandleErrorFuture(Kind::Response(Some(res)));
}
Err(err) => {
return HandleErrorFuture(Kind::Error(Some(err)));
}
}
}
HandleErrorFuture(Kind::Future(self.app.router.call(req)))
}
}
#[pin_project]
pub struct HandleErrorFuture<F, B>(#[pin] Kind<F, B>);
#[pin_project(project = KindProj)]
enum Kind<F, B> {
Response(Option<Response<B>>),
Error(Option<Error>),
Future(#[pin] F),
}
impl<F, B, E> Future for HandleErrorFuture<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
E: Into<Error>,
B: Default,
{
type Output = Result<Response<B>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().0.project() {
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
KindProj::Error(err) => Poll::Ready(Err(err.take().unwrap())),
KindProj::Future(fut) => match ready!(fut.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => Poll::Ready(handle_error(err.into())),
},
}
}
}
fn handle_error<B>(error: Error) -> Result<Response<B>, Error>
where
B: Default,
{
fn make_response<B>(status: StatusCode) -> Result<Response<B>, Error>
where
B: Default,
{
let mut res = Response::new(B::default());
*res.status_mut() = status;
Ok(res)
}
match error {
Error::DeserializeRequestBody(_)
| Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err),
Err(err) => Err(Error::Service(err)),
},
err @ Error::ConsumeRequestBody(_) => Err(err),
err @ Error::ResponseBody(_) => Err(err),
} }
} }
@ -648,7 +711,10 @@ mod tests {
use tower::{ use tower::{
layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder, layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder,
}; };
use tower_http::trace::TraceLayer; use tower_http::{
compression::CompressionLayer,
trace::{Trace, TraceLayer},
};
#[tokio::test] #[tokio::test]
async fn basic() { async fn basic() {
@ -667,9 +733,15 @@ mod tests {
Ok(Response::new(Body::from("Hello, World!"))) Ok(Response::new(Body::from("Hello, World!")))
} }
let mut app = app() async fn large_static_file(_: Request<Body>) -> Result<Response<Body>, Error> {
Ok(Response::new(Body::empty()))
}
let app = app()
// routes with functions
.at("/") .at("/")
.get(root.layer(TimeoutLayer::new(Duration::from_secs(30)))) .get(root)
// routes with closures
.at("/users") .at("/users")
.get(|_: Request<Body>, pagination: Query<Pagination>| async { .get(|_: Request<Body>, pagination: Query<Pagination>| async {
let pagination = pagination.into_inner(); let pagination = pagination.into_inner();
@ -684,8 +756,23 @@ mod tests {
Ok::<_, Error>(Response::new(Body::from("users#create"))) Ok::<_, Error>(Response::new(Body::from("users#create")))
}) })
// routes with a service
.at("/service") .at("/service")
.get_service(service_fn(root)); .get_service(service_fn(root))
// routes with layers applied
.at("/large-static-file")
.get(
large_static_file.layer(
ServiceBuilder::new()
.layer(TimeoutLayer::new(Duration::from_secs(30)))
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
// can add more middleware
let mut app = Trace::new_for_http(app);
let res = app let res = app
.ready() .ready()
@ -719,6 +806,22 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "users#index"); assert_eq!(body_to_string(res).await, "users#index");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/users")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(body_to_string(res).await, "");
let res = app let res = app
.ready() .ready()
.await .await
@ -748,12 +851,16 @@ mod tests {
#[allow(dead_code)] #[allow(dead_code)]
// this should just compile // this should just compile
async fn compatible_with_hyper_and_tower_http() { async fn compatible_with_hyper_and_tower_http() {
let app = app().at("/").get(|_: Request<Body>| async { let app = app()
Ok::<_, Error>(Response::new(Body::from("Hello, World!"))) .at("/")
}); .get(|_: Request<Body>| async {
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
})
.into_service();
let app = ServiceBuilder::new() let app = ServiceBuilder::new()
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.service(app); .service(app);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000));