More error handling

This commit is contained in:
David Pedersen 2021-06-01 00:34:09 +02:00
parent 18f613ff98
commit 0e38037c74
10 changed files with 222 additions and 228 deletions

View file

@ -2,7 +2,7 @@
This is *not* https://github.com/carllerche/tower-web even though the name is
the same. Its just a prototype of a minimal HTTP framework I've been toying
with.
with. Will probably change the name to something else.
# What is this?

View file

@ -1,9 +1,6 @@
#![allow(warnings)]
use bytes::Bytes;
use http::{Request, Response, StatusCode};
use http::{Request, StatusCode};
use hyper::Server;
use serde::Deserialize;
use std::{
collections::HashMap,
net::SocketAddr,
@ -14,12 +11,7 @@ use tower::{make::Shared, ServiceBuilder};
use tower_http::{
add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer,
};
use tower_web::{
body::Body,
extract,
response::{self, IntoResponse},
Error,
};
use tower_web::{body::Body, extract};
#[tokio::main]
async fn main() {
@ -59,7 +51,11 @@ async fn get(
_req: Request<Body>,
params: extract::UrlParams<(String,)>,
state: extract::Extension<SharedState>,
) -> Result<Bytes, NotFound> {
// Anything that implements `IntoResponse` can be used a response
//
// Handlers cannot return errors. Everything will be converted
// into a response. `BoxError` becomes `500 Internal server error`
) -> Result<Bytes, StatusCode> {
let state = state.into_inner();
let db = &state.lock().unwrap().db;
@ -68,7 +64,7 @@ async fn get(
if let Some(value) = db.get(&key) {
Ok(value.clone())
} else {
Err(NotFound)
Err(StatusCode::NOT_FOUND)
}
}
@ -77,6 +73,8 @@ async fn set(
params: extract::UrlParams<(String,)>,
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>,
// `()` also implements `IntoResponse` so we can use that to return
// an empty response
) {
let state = state.into_inner();
let db = &mut state.lock().unwrap().db;
@ -84,16 +82,5 @@ async fn set(
let key = params.into_inner();
let value = value.into_inner();
db.insert(key.to_string(), value);
}
struct NotFound;
impl IntoResponse<Body> for NotFound {
fn into_response(self) -> Response<Body> {
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap()
}
db.insert(key, value);
}

View file

@ -1,4 +1,5 @@
use bytes::Buf;
use futures_util::ready;
use http_body::{Body as _, Empty};
use std::{
fmt,
@ -8,6 +9,8 @@ use std::{
pub use hyper::body::Body;
use crate::BoxStdError;
/// A boxed [`Body`] trait object.
pub struct BoxBody<D, E> {
inner: Pin<Box<dyn http_body::Body<Data = D, Error = E> + Send + Sync + 'static>>,
@ -42,25 +45,34 @@ impl<D, E> fmt::Debug for BoxBody<D, E> {
}
}
// when we've gotten rid of `BoxStdError` then we can remove this
impl<D, E> http_body::Body for BoxBody<D, E>
where
D: Buf,
E: Into<tower::BoxError>,
{
type Data = D;
type Error = E;
type Error = BoxStdError;
fn poll_data(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.inner.as_mut().poll_data(cx)
match ready!(self.inner.as_mut().poll_data(cx)) {
Some(Ok(chunk)) => Some(Ok(chunk)).into(),
Some(Err(err)) => Some(Err(BoxStdError(err.into()))).into(),
None => None.into(),
}
}
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
self.inner.as_mut().poll_trailers(cx)
match ready!(self.inner.as_mut().poll_trailers(cx)) {
Ok(trailers) => Ok(trailers).into(),
Err(err) => Err(BoxStdError(err.into())).into(),
}
}
fn is_end_stream(&self) -> bool {
@ -71,3 +83,14 @@ where
self.inner.size_hint()
}
}
impl From<String> for BoxBody<bytes::Bytes, tower::BoxError> {
fn from(s: String) -> Self {
let body = hyper::Body::from(s);
let body = body.map_err(Into::<tower::BoxError>::into);
BoxBody {
inner: Box::pin(body),
}
}
}

View file

@ -1,104 +0,0 @@
use std::convert::Infallible;
use http::{Response, StatusCode};
use tower::BoxError;
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
#[error("failed to deserialize the request body")]
DeserializeRequestBody(#[source] serde_json::Error),
#[error("failed to serialize the response body")]
SerializeResponseBody(#[source] serde_json::Error),
#[error("failed to consume the body")]
ConsumeRequestBody(#[source] hyper::Error),
#[error("URI contained no query string")]
QueryStringMissing,
#[error("failed to deserialize query string")]
DeserializeQueryString(#[source] serde_urlencoded::de::Error),
#[error("failed generating the response body")]
ResponseBody(#[source] BoxError),
#[error("some dynamic error happened")]
Dynamic(#[source] BoxError),
#[error("request extension of type `{type_name}` was not set")]
MissingExtension { type_name: &'static str },
#[error("`Content-Length` header is missing but was required")]
LengthRequired,
#[error("response body was too large")]
PayloadTooLarge,
#[error("response failed with status {0}")]
Status(StatusCode),
#[error("invalid URL param. Expected something of type `{type_name}`")]
InvalidUrlParam { type_name: &'static str },
#[error("unknown URL param `{0}`")]
UnknownUrlParam(String),
#[error("response body didn't contain valid UTF-8")]
InvalidUtf8,
}
impl From<BoxError> for Error {
fn from(err: BoxError) -> Self {
match err.downcast::<Error>() {
Ok(err) => *err,
Err(err) => Error::Dynamic(err),
}
}
}
impl From<Infallible> for Error {
fn from(err: Infallible) -> Self {
match err {}
}
}
pub(crate) fn handle_error<B>(error: Error) -> Result<Response<B>, Error>
where
B: Default,
{
fn make_response<B>(status: StatusCode) -> Result<Response<B>, Error>
where
B: Default,
{
let mut res = Response::new(B::default());
*res.status_mut() = status;
Ok(res)
}
match error {
Error::DeserializeRequestBody(_)
| Error::QueryStringMissing
| Error::DeserializeQueryString(_)
| Error::InvalidUrlParam { .. }
| Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST),
Error::Status(status) => make_response(status),
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
Error::MissingExtension { .. }
| Error::SerializeResponseBody(_)
| Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR),
Error::Dynamic(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err),
Err(err) => Err(Error::Dynamic(err)),
},
err @ Error::ConsumeRequestBody(_) => Err(err),
err @ Error::ResponseBody(_) => Err(err),
}
}

View file

@ -1,7 +1,6 @@
use crate::{
body::Body,
response::{BoxIntoResponse, IntoResponse},
Error,
};
use async_trait::async_trait;
use bytes::Bytes;
@ -315,21 +314,15 @@ define_rejection! {
pub struct UrlParamsMap(HashMap<String, String>);
impl UrlParamsMap {
pub fn get(&self, key: &str) -> Result<&str, Error> {
if let Some(value) = self.0.get(key) {
Ok(value)
} else {
Err(Error::UnknownUrlParam(key.to_string()))
}
pub fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).map(|s| &**s)
}
pub fn get_typed<T>(&self, key: &str) -> Result<T, Error>
pub fn get_typed<T>(&self, key: &str) -> Option<T>
where
T: FromStr,
{
self.get(key)?.parse().map_err(|_| Error::InvalidUrlParam {
type_name: std::any::type_name::<T>(),
})
self.get(key)?.parse().ok()
}
}

View file

@ -1,4 +1,4 @@
use crate::{body::Body, error::Error, extract::FromRequest, response::IntoResponse};
use crate::{body::Body, extract::FromRequest, response::IntoResponse};
use async_trait::async_trait;
use futures_util::future;
use http::{Request, Response};
@ -8,7 +8,7 @@ use std::{
marker::PhantomData,
task::{Context, Poll},
};
use tower::{BoxError, Layer, Service, ServiceExt};
use tower::{Layer, Service, ServiceExt};
mod sealed {
pub trait HiddentTrait {}
@ -30,6 +30,8 @@ pub trait Handler<B, In>: Sized {
fn layer<L>(self, layer: L) -> Layered<L::Service, In>
where
L: Layer<HandlerSvc<Self, B, In>>,
<L as Layer<HandlerSvc<Self, B, In>>>::Service: Service<Request<Body>>,
<<L as Layer<HandlerSvc<Self, B, In>>>::Service as Service<Request<Body>>>::Error: IntoResponse<B>,
{
Layered::new(layer.layer(HandlerSvc::new(self)))
}

View file

@ -7,6 +7,7 @@ use futures_util::ready;
use http::Response;
use pin_project::pin_project;
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
@ -19,13 +20,9 @@ pub mod handler;
pub mod response;
pub mod routing;
mod error;
#[cfg(test)]
mod tests;
pub use self::error::Error;
pub fn app() -> App<AlwaysNotFound> {
App {
service_tree: AlwaysNotFound(()),
@ -52,7 +49,6 @@ impl<R> App<R> {
pub struct IntoService<R> {
app: App<R>,
poll_ready_error: Option<Error>,
}
impl<R> Clone for IntoService<R>
@ -62,71 +58,67 @@ where
fn clone(&self) -> Self {
Self {
app: self.app.clone(),
poll_ready_error: None,
}
}
}
impl<R, B, T> Service<T> for IntoService<R>
where
R: Service<T, Response = Response<B>>,
R::Error: Into<Error>,
R: Service<T, Response = Response<B>, Error = Infallible>,
B: Default,
{
type Response = Response<B>;
type Error = Error;
type Future = HandleErrorFuture<R::Future, B>;
type Error = Infallible;
type Future = HandleErrorFuture<R::Future>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Err(err) = ready!(self.app.service_tree.poll_ready(cx)).map_err(Into::into) {
self.poll_ready_error = Some(err);
match ready!(self.app.service_tree.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => match err {},
}
Poll::Ready(Ok(()))
}
fn call(&mut self, req: T) -> Self::Future {
if let Some(poll_ready_error) = self.poll_ready_error.take() {
match error::handle_error::<B>(poll_ready_error) {
Ok(res) => {
return HandleErrorFuture(Kind::Response(Some(res)));
}
Err(err) => {
return HandleErrorFuture(Kind::Error(Some(err)));
}
}
}
HandleErrorFuture(Kind::Future(self.app.service_tree.call(req)))
HandleErrorFuture(self.app.service_tree.call(req))
}
}
#[pin_project]
pub struct HandleErrorFuture<F, B>(#[pin] Kind<F, B>);
pub struct HandleErrorFuture<F>(#[pin] F);
#[pin_project(project = KindProj)]
enum Kind<F, B> {
Response(Option<Response<B>>),
Error(Option<Error>),
Future(#[pin] F),
}
impl<F, B, E> Future for HandleErrorFuture<F, B>
impl<F, B> Future for HandleErrorFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
E: Into<Error>,
F: Future<Output = Result<Response<B>, Infallible>>,
B: Default,
{
type Output = Result<Response<B>, Error>;
type Output = Result<Response<B>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().0.project() {
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
KindProj::Error(err) => Poll::Ready(Err(err.take().unwrap())),
KindProj::Future(fut) => match ready!(fut.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => Poll::Ready(error::handle_error(err.into())),
},
self.project().0.poll(cx)
}
}
pub(crate) trait ResultExt<T> {
fn unwrap_infallible(self) -> T;
}
impl<T> ResultExt<T> for Result<T, Infallible> {
fn unwrap_infallible(self) -> T {
match self {
Ok(value) => value,
Err(err) => match err {},
}
}
}
// work around for `BoxError` not implementing `std::error::Error`
//
// This is currently required since tower-http's Compression middleware's body type's
// error only implements error when the inner error type does:
// https://github.com/tower-rs/tower-http/blob/master/tower-http/src/lib.rs#L310
//
// Fixing that is a breaking change to tower-http so we should wait a bit, but should
// totally fix it at some point.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct BoxStdError(#[source] tower::BoxError);

View file

@ -1,6 +1,6 @@
use crate::Body;
use bytes::Bytes;
use http::{header, HeaderValue, Response, StatusCode};
use http::{HeaderMap, HeaderValue, Response, StatusCode, header};
use serde::Serialize;
use std::convert::Infallible;
use tower::{util::Either, BoxError};
@ -176,9 +176,45 @@ impl<B> IntoResponse<B> for BoxIntoResponse<B> {
impl IntoResponse<Body> for BoxError {
fn into_response(self) -> Response<Body> {
// TODO(david): test for know error types like std::io::Error
// or common errors types from tower and map those more appropriately
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(self.to_string()))
.unwrap()
}
}
impl<B> IntoResponse<B> for StatusCode
where
B: Default,
{
fn into_response(self) -> Response<B> {
Response::builder().status(self).body(B::default()).unwrap()
}
}
impl<T> IntoResponse<Body> for (StatusCode, T)
where
T: Into<Body>,
{
fn into_response(self) -> Response<Body> {
Response::builder()
.status(self.0)
.body(self.1.into())
.unwrap()
}
}
impl<T> IntoResponse<Body> for (StatusCode, HeaderMap, T)
where
T: Into<Body>,
{
fn into_response(self) -> Response<Body> {
let mut res = Response::new(self.2.into());
*res.status_mut() = self.0;
*res.headers_mut() = self.1;
res
}
}

View file

@ -1,8 +1,7 @@
use crate::{
body::{Body, BoxBody},
error::Error,
handler::{Handler, HandlerSvc},
App, IntoService,
App, IntoService, ResultExt,
};
use bytes::Bytes;
use futures_util::{future, ready};
@ -164,17 +163,18 @@ impl<R> RouteBuilder<R> {
}
pub fn into_service(self) -> IntoService<R> {
IntoService {
app: self.app,
poll_ready_error: None,
}
IntoService { app: self.app }
}
// TODO(david): Add `layer` method here that applies a `tower::Layer` inside the service tree
// that way we get to map errors
pub fn boxed<B>(self) -> RouteBuilder<BoxServiceTree<B>>
where
R: Service<Request<Body>, Response = Response<B>, Error = Error> + Send + 'static,
R: Service<Request<Body>, Response = Response<B>, Error = Infallible> + Send + 'static,
R::Future: Send,
B: Default + 'static,
// TODO(david): do we still need default here
B: Default + From<String> + 'static,
{
let svc = ServiceBuilder::new()
.layer(BufferLayer::new(1024))
@ -182,7 +182,10 @@ impl<R> RouteBuilder<R> {
.service(self.app.service_tree);
let app = App {
service_tree: BoxServiceTree { inner: svc },
service_tree: BoxServiceTree {
inner: svc,
poll_ready_error: None,
},
};
RouteBuilder {
@ -270,29 +273,27 @@ impl RouteSpec {
impl<H, F, HB, FB> Service<Request<Body>> for Or<H, F>
where
H: Service<Request<Body>, Response = Response<HB>>,
H::Error: Into<Error>,
H: Service<Request<Body>, Response = Response<HB>, Error = Infallible>,
HB: http_body::Body + Send + Sync + 'static,
HB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>>,
F::Error: Into<Error>,
F: Service<Request<Body>, Response = Response<FB>, Error = Infallible>,
FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static,
FB::Error: Into<BoxError>,
{
type Response = Response<BoxBody<HB::Data, Error>>;
type Error = Error;
type Response = Response<BoxBody<HB::Data, BoxError>>;
type Error = Infallible;
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
if !self.handler_ready {
ready!(self.service.poll_ready(cx)).map_err(Into::into)?;
ready!(self.service.poll_ready(cx)).unwrap_infallible();
self.handler_ready = true;
}
if !self.fallback_ready {
ready!(self.fallback.poll_ready(cx)).map_err(Into::into)?;
ready!(self.fallback.poll_ready(cx)).unwrap_infallible();
self.fallback_ready = true;
}
@ -333,20 +334,18 @@ pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
#[pin_project]
pub struct BoxResponseBody<F>(#[pin] F);
impl<F, B, E> Future for BoxResponseBody<F>
impl<F, B> Future for BoxResponseBody<F>
where
F: Future<Output = Result<Response<B>, E>>,
E: Into<Error>,
F: Future<Output = Result<Response<B>, Infallible>>,
B: http_body::Body + Send + Sync + 'static,
B::Error: Into<BoxError>,
{
type Output = Result<Response<BoxBody<B::Data, Error>>, Error>;
type Output = Result<Response<BoxBody<B::Data, BoxError>>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?;
let response: Response<B> = ready!(self.project().0.poll(cx)).unwrap_infallible();
let response = response.map(|body| {
// TODO(david): attempt to downcast this into `Error`
let body = body.map_err(|err| Error::ResponseBody(err.into()));
let body = body.map_err(Into::into);
BoxBody::new(body)
});
Poll::Ready(Ok(response))
@ -354,13 +353,15 @@ where
}
pub struct BoxServiceTree<B> {
inner: Buffer<BoxService<Request<Body>, Response<B>, Error>, Request<Body>>,
inner: Buffer<BoxService<Request<Body>, Response<B>, Infallible>, Request<Body>>,
poll_ready_error: Option<BoxError>,
}
impl<B> Clone for BoxServiceTree<B> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
poll_ready_error: None,
}
}
}
@ -373,21 +374,36 @@ impl<B> fmt::Debug for BoxServiceTree<B> {
impl<B> Service<Request<Body>> for BoxServiceTree<B>
where
B: 'static,
B: From<String> + 'static,
{
type Response = Response<B>;
type Error = Error;
type Error = Infallible;
type Future = BoxServiceTreeResponseFuture<B>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Error::from)
// TODO(david): downcast this into one of the cases in `tower::buffer::error`
// and convert the error into a response. `ServiceError` should never be able to happen
// since all inner services use `Infallible` as the error type.
match ready!(self.inner.poll_ready(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(err) => {
self.poll_ready_error = Some(err);
Poll::Ready(Ok(()))
}
}
}
#[inline]
fn call(&mut self, req: Request<Body>) -> Self::Future {
if let Some(err) = self.poll_ready_error.take() {
return BoxServiceTreeResponseFuture {
kind: Kind::Response(Some(handle_buffer_error(err))),
};
}
BoxServiceTreeResponseFuture {
inner: self.inner.call(req),
kind: Kind::Future(self.inner.call(req)),
}
}
}
@ -395,24 +411,71 @@ where
#[pin_project]
pub struct BoxServiceTreeResponseFuture<B> {
#[pin]
inner: InnerFuture<B>,
kind: Kind<B>,
}
#[pin_project(project = KindProj)]
enum Kind<B> {
Response(Option<Response<B>>),
Future(#[pin] InnerFuture<B>),
}
type InnerFuture<B> = tower::buffer::future::ResponseFuture<
Pin<Box<dyn Future<Output = Result<Response<B>, Error>> + Send + 'static>>,
Pin<Box<dyn Future<Output = Result<Response<B>, Infallible>> + Send + 'static>>,
>;
impl<B> Future for BoxServiceTreeResponseFuture<B> {
type Output = Result<Response<B>, Error>;
impl<B> Future for BoxServiceTreeResponseFuture<B>
where
B: From<String>,
{
type Output = Result<Response<B>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project()
.inner
.poll(cx)
.map_err(Error::from)
match self.project().kind.project() {
KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())),
KindProj::Future(future) => match ready!(future.poll(cx)) {
Ok(res) => Poll::Ready(Ok(res)),
Err(err) => Poll::Ready(Ok(handle_buffer_error(err))),
},
}
}
}
fn handle_buffer_error<B>(error: BoxError) -> Response<B>
where
B: From<String>,
{
use tower::buffer::error::{Closed, ServiceError};
let error = match error.downcast::<Closed>() {
Ok(closed) => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(closed.to_string()))
.unwrap();
}
Err(e) => e,
};
let error = match error.downcast::<ServiceError>() {
Ok(service_error) => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(format!("Service error: {}. This is a bug in tower-web. All inner services should be infallible. Please file an issue", service_error)))
.unwrap();
}
Err(e) => e,
};
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(B::from(format!(
"Uncountered an unknown error: {}. This should never happen. Please file an issue",
error
)))
.unwrap()
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]

View file

@ -273,6 +273,8 @@ async fn boxing() {
assert_eq!(res.text().await.unwrap(), "hi from POST");
}
// TODO(david): tests for adding middleware to single services
/// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where