Support nesting services with error handling

This commit is contained in:
David Pedersen 2021-06-01 11:23:56 +02:00
parent 093ad3622e
commit f690e74275
8 changed files with 288 additions and 96 deletions

View file

@ -24,6 +24,15 @@ reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
tower = { version = "0.4", features = ["util", "make", "timeout"] }
tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] }
tracing = "0.1"
tracing-subscriber = "0.2"
[dev-dependencies.tower-http]
version = "0.1"
features = [
"trace",
"compression",
"compression-full",
"add-extension",
"fs",
]

View file

@ -1,9 +1,8 @@
use http::Request;
use http::{Request, StatusCode};
use hyper::Server;
use std::net::SocketAddr;
use tower::{make::Shared, ServiceBuilder};
use tower_http::trace::TraceLayer;
use tower_web::{body::Body, response::Html};
use tower::make::Shared;
use tower_web::{body::Body, extract, response::Html};
#[tokio::main]
async fn main() {
@ -13,14 +12,11 @@ async fn main() {
let app = tower_web::app()
.at("/")
.get(handler)
.at("/greet/:name")
.get(greet)
// convert it into a `Service`
.into_service();
// add some middleware
let app = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.service(app);
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
@ -31,3 +27,12 @@ async fn main() {
async fn handler(_req: Request<Body>) -> Html<&'static str> {
Html("<h1>Hello, World!</h1>")
}
async fn greet(_req: Request<Body>, params: extract::UrlParamsMap) -> Result<String, StatusCode> {
if let Some(name) = params.get("name") {
Ok(format!("Hello {}!", name))
} else {
// if the route matches "name" will be present
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}

View file

@ -11,7 +11,7 @@ use tower::{make::Shared, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
};
use tower_web::{body::Body, extract};
use tower_web::{body::Body, extract, handler::Handler};
#[tokio::main]
async fn main() {
@ -20,7 +20,7 @@ async fn main() {
// build our application with some routes
let app = tower_web::app()
.at("/:key")
.get(get)
.get(get.layer(CompressionLayer::new()))
.post(set)
// convert it into a `Service`
.into_service();
@ -29,7 +29,6 @@ async fn main() {
let app = ServiceBuilder::new()
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.layer(AddExtensionLayer::new(SharedState::default()))
.service(app);
@ -51,10 +50,6 @@ async fn get(
_req: Request<Body>,
params: extract::UrlParams<(String,)>,
state: extract::Extension<SharedState>,
// Anything that implements `IntoResponse` can be used a response
//
// Handlers cannot return errors. Everything will be converted
// into a response. `BoxError` becomes `500 Internal server error`
) -> Result<Bytes, StatusCode> {
let state = state.into_inner();
let db = &state.lock().unwrap().db;
@ -73,8 +68,6 @@ async fn set(
params: extract::UrlParams<(String,)>,
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>,
// `()` also implements `IntoResponse` so we can use that to return
// an empty response
) {
let state = state.into_inner();
let db = &mut state.lock().unwrap().db;

View file

@ -1,78 +1,64 @@
use bytes::Buf;
use futures_util::ready;
use http_body::{Body as _, Empty};
use bytes::Bytes;
use http_body::{Empty, Full};
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
use tower::BoxError;
pub use hyper::body::Body;
use crate::BoxStdError;
/// A boxed [`Body`] trait object.
pub struct BoxBody<D, E> {
inner: Pin<Box<dyn http_body::Body<Data = D, Error = E> + Send + Sync + 'static>>,
pub struct BoxBody {
// when we've gotten rid of `BoxStdError` we should be able to change the error type to
// `BoxError`
inner: Pin<Box<dyn http_body::Body<Data = Bytes, Error = BoxStdError> + Send + Sync + 'static>>,
}
impl<D, E> BoxBody<D, E> {
impl BoxBody {
/// Create a new `BoxBody`.
pub fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = D, Error = E> + Send + Sync + 'static,
D: Buf,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
Self {
inner: Box::pin(body),
inner: Box::pin(body.map_err(|error| BoxStdError(error.into()))),
}
}
}
// TODO: upstream this to http-body?
impl<D, E> Default for BoxBody<D, E>
where
D: bytes::Buf + 'static,
{
impl Default for BoxBody {
fn default() -> Self {
BoxBody::new(Empty::<D>::new().map_err(|err| match err {}))
BoxBody::new(Empty::<Bytes>::new())
}
}
impl<D, E> fmt::Debug for BoxBody<D, E> {
impl fmt::Debug for BoxBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxBody").finish()
}
}
// when we've gotten rid of `BoxStdError` then we can remove this
impl<D, E> http_body::Body for BoxBody<D, E>
where
D: Buf,
E: Into<tower::BoxError>,
{
type Data = D;
impl http_body::Body for BoxBody {
type Data = Bytes;
type Error = BoxStdError;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match ready!(self.inner.as_mut().poll_data(cx)) {
Some(Ok(chunk)) => Some(Ok(chunk)).into(),
Some(Err(err)) => Some(Err(BoxStdError(err.into()))).into(),
None => None.into(),
}
self.inner.as_mut().poll_data(cx)
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
match ready!(self.inner.as_mut().poll_trailers(cx)) {
Ok(trailers) => Ok(trailers).into(),
Err(err) => Err(BoxStdError(err.into())).into(),
}
self.inner.as_mut().poll_trailers(cx)
}
fn is_end_stream(&self) -> bool {
@ -84,13 +70,11 @@ where
}
}
impl From<String> for BoxBody<bytes::Bytes, tower::BoxError> {
fn from(s: String) -> Self {
let body = hyper::Body::from(s);
let body = body.map_err(Into::<tower::BoxError>::into);
BoxBody {
inner: Box::pin(body),
}
impl<B> From<B> for BoxBody
where
B: Into<Bytes>,
{
fn from(s: B) -> Self {
BoxBody::new(Full::from(s.into()))
}
}

View file

@ -2,17 +2,19 @@ use self::{
body::Body,
routing::{AlwaysNotFound, RouteAt},
};
use body::BoxBody;
use bytes::Bytes;
use futures_util::ready;
use http::Response;
use http::{Request, Response};
use pin_project::pin_project;
use std::{
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::Service;
use tower::{BoxError, Service};
pub mod body;
pub mod extract;
@ -69,7 +71,7 @@ where
{
type Response = Response<B>;
type Error = Infallible;
type Future = HandleErrorFuture<R::Future>;
type Future = R::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(self.app.service_tree.poll_ready(cx)) {
@ -79,22 +81,7 @@ where
}
fn call(&mut self, req: T) -> Self::Future {
HandleErrorFuture(self.app.service_tree.call(req))
}
}
#[pin_project]
pub struct HandleErrorFuture<F>(#[pin] F);
impl<F, B> Future for HandleErrorFuture<F>
where
F: Future<Output = Result<Response<B>, Infallible>>,
B: Default,
{
type Output = Result<Response<B>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().0.poll(cx)
self.app.service_tree.call(req)
}
}
@ -121,4 +108,141 @@ impl<T> ResultExt<T> for Result<T, Infallible> {
// totally fix it at some point.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct BoxStdError(#[source] tower::BoxError);
pub struct BoxStdError(#[source] pub(crate) tower::BoxError);
pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
fn handle_error<F, NewBody>(self, f: F) -> HandleError<Self, F, Self::Error>
where
Self: Sized,
F: FnOnce(Self::Error) -> Response<NewBody>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
HandleError {
inner: self,
f,
poll_ready_error: None,
}
}
}
impl<S, B> ServiceExt<B> for S where S: Service<Request<Body>, Response = Response<B>> {}
pub struct HandleError<S, F, E> {
inner: S,
f: F,
poll_ready_error: Option<E>,
}
impl<S, F, E> fmt::Debug for HandleError<S, F, E>
where
S: fmt::Debug,
E: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.field("poll_ready_error", &self.poll_ready_error)
.finish()
}
}
impl<S, F, E> Clone for HandleError<S, F, E>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
f: self.f.clone(),
poll_ready_error: None,
}
}
}
impl<S, F, B, NewBody> Service<Request<Body>> for HandleError<S, F, S::Error>
where
S: Service<Request<Body>, Response = Response<B>>,
F: FnOnce(S::Error) -> Response<NewBody> + Clone,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<S::Future, F, S::Error>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match ready!(self.inner.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => {
self.poll_ready_error = Some(err);
Poll::Ready(Ok(()))
}
}
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
if let Some(err) = self.poll_ready_error.take() {
return HandleErrorFuture {
f: Some(self.f.clone()),
kind: Kind::Error(Some(err)),
};
}
HandleErrorFuture {
f: Some(self.f.clone()),
kind: Kind::Future(self.inner.call(req)),
}
}
}
#[pin_project]
pub struct HandleErrorFuture<Fut, F, E> {
#[pin]
kind: Kind<Fut, E>,
f: Option<F>,
}
#[pin_project(project = KindProj)]
enum Kind<Fut, E> {
Future(#[pin] Fut),
Error(Option<E>),
}
impl<Fut, F, E, B, NewBody> Future for HandleErrorFuture<Fut, F, E>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Response<NewBody>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.kind.project() {
KindProj::Future(future) => match ready!(future.poll(cx)) {
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err);
Ok(res.map(BoxBody::new)).into()
}
},
KindProj::Error(err) => {
let f = this.f.take().unwrap();
let res = f(err.take().unwrap());
Ok(res.map(BoxBody::new)).into()
}
}
}
}

View file

@ -8,6 +8,8 @@ use tower::util::Either;
pub trait IntoResponse<B> {
fn into_response(self) -> Response<B>;
// TODO(david): remove this an return return `Response<B>` instead. That is what this method
// does anyway.
fn boxed(self) -> BoxIntoResponse<B>
where
Self: Sized + 'static,

View file

@ -55,8 +55,7 @@ impl<R> RouteAt<R> {
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.add_route_service(service, Method::GET)
}
@ -70,8 +69,7 @@ impl<R> RouteAt<R> {
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.add_route_service(service, Method::POST)
}
@ -141,8 +139,7 @@ impl<R> RouteBuilder<R> {
pub fn get_service<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.app.at_bytes(self.route_spec).get_service(service)
}
@ -156,8 +153,7 @@ impl<R> RouteBuilder<R> {
pub fn post_service<S, B>(self, service: S) -> RouteBuilder<Or<S, R>>
where
S: Service<Request<Body>, Response = Response<B>> + Clone,
S::Error: Into<BoxError>,
S: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Clone,
{
self.app.at_bytes(self.route_spec).post_service(service)
}
@ -173,7 +169,6 @@ impl<R> RouteBuilder<R> {
where
R: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
R::Future: Send,
// TODO(david): do we still need default here
B: From<String> + 'static,
{
let svc = ServiceBuilder::new()
@ -274,14 +269,14 @@ impl RouteSpec {
impl<H, F, HB, FB> Service<Request<Body>> for Or<H, F>
where
H: Service<Request<Body>, Response = Response<HB>, Error = Infallible>,
HB: http_body::Body + Send + Sync + 'static,
HB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
HB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible>,
FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static,
FB: http_body::Body<Data = Bytes> + Send + Sync + 'static,
FB::Error: Into<BoxError>,
{
type Response = Response<BoxBody<HB::Data, BoxError>>;
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
@ -337,10 +332,10 @@ pub struct BoxResponseBody<F>(#[pin] F);
impl<F, B> Future for BoxResponseBody<F>
where
F: Future<Output = Result<Response<B>, Infallible>>,
B: http_body::Body + Send + Sync + 'static,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
type Output = Result<Response<BoxBody<B::Data, BoxError>>, Infallible>;
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let response: Response<B> = ready!(self.project().0.poll(cx)).unwrap_infallible();

View file

@ -1,4 +1,4 @@
use crate::{app, extract};
use crate::{app, extract, handler::Handler};
use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
@ -273,9 +273,89 @@ async fn boxing() {
assert_eq!(res.text().await.unwrap(), "hi from POST");
}
// TODO(david): tests for adding middleware to single services
#[tokio::test]
async fn service_handlers() {
use crate::{body::BoxBody, ServiceExt as _};
use std::convert::Infallible;
use tower::service_fn;
use tower_http::services::ServeFile;
// TODO(david): tests for nesting services
let app = app()
.at("/echo")
.post_service(service_fn(|req: Request<Body>| async move {
Ok::<_, Infallible>(Response::new(req.into_body()))
}))
// calling boxed isn't necessary here but done so
// we're sure it compiles
.boxed()
.at("/static/Cargo.toml")
.get_service(
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
// `ServeFile` internally maps some errors to `404` so we don't have
// to handle those here
let body = BoxBody::from(error.to_string());
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.unwrap()
}),
)
// calling boxed isn't necessary here but done so
// we're sure it compiles
.boxed()
.into_service();
let addr = run_in_background(app).await;
let client = reqwest::Client::new();
let res = client
.post(format!("http://{}/echo", addr))
.body("foobar")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await.unwrap(), "foobar");
let res = client
.get(format!("http://{}/static/Cargo.toml", addr))
.body("foobar")
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert!(res.text().await.unwrap().contains("edition ="));
}
#[tokio::test]
async fn middleware_on_single_route() {
use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
async fn handle(_: Request<Body>) -> &'static str {
"Hello, World!"
}
let app = app()
.at("/")
.get(
handle.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.into_inner(),
),
)
.into_service();
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
let body = res.text().await.unwrap();
assert_eq!(body, "Hello, World!");
}
/// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr