1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-03 21:15:55 +02:00

More error handling of layered handlers

This commit is contained in:
David Pedersen 2021-06-01 17:17:10 +02:00
parent d7a0715188
commit 8ee3119fb0
3 changed files with 76 additions and 34 deletions

View file

@ -1,5 +1,6 @@
use crate::{body::Body, extract::FromRequest, response::IntoResponse};
use crate::{body::Body, HandleError, extract::FromRequest, response::IntoResponse};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::future;
use http::{Request, Response};
use std::{
@ -8,7 +9,7 @@ use std::{
marker::PhantomData,
task::{Context, Poll},
};
use tower::{Layer, Service, ServiceExt};
use tower::{Layer, BoxError, Service, ServiceExt};
mod sealed {
pub trait HiddentTrait {}
@ -141,6 +142,18 @@ impl<S, T> Layered<S, T> {
_input: PhantomData,
}
}
pub fn handle_error<F, B, Res>(self, f: F) -> Layered<HandleError<S, F, S::Error>, T>
where
S: Service<Request<Body>, Response = Response<B>>,
F: FnOnce(S::Error) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
let svc = HandleError::new(self.svc, f);
Layered::new(svc)
}
}
pub struct HandlerSvc<H, B, T> {

View file

@ -7,6 +7,7 @@ use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use pin_project::pin_project;
use response::IntoResponse;
use std::{
convert::Infallible,
fmt,
@ -22,8 +23,8 @@ pub mod handler;
pub mod response;
pub mod routing;
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
pub use async_trait::async_trait;
pub use tower_http::add_extension::{AddExtension, AddExtensionLayer};
#[cfg(test)]
mod tests;
@ -114,20 +115,15 @@ impl<T> ResultExt<T> for Result<T, Infallible> {
pub struct BoxStdError(#[source] pub(crate) tower::BoxError);
pub trait ServiceExt<B>: Service<Request<Body>, Response = Response<B>> {
fn handle_error<F, NewBody>(self, f: F) -> HandleError<Self, F, Self::Error>
fn handle_error<F, Res>(self, f: F) -> HandleError<Self, F, Self::Error>
where
Self: Sized,
F: FnOnce(Self::Error) -> Response<NewBody>,
F: FnOnce(Self::Error) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
HandleError {
inner: self,
f,
poll_ready_error: None,
}
HandleError::new(self, f)
}
}
@ -139,6 +135,16 @@ pub struct HandleError<S, F, E> {
poll_ready_error: Option<E>,
}
impl<S, F, E> HandleError<S, F, E> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
poll_ready_error: None,
}
}
}
impl<S, F, E> fmt::Debug for HandleError<S, F, E>
where
S: fmt::Debug,
@ -167,14 +173,13 @@ where
}
}
impl<S, F, B, NewBody> Service<Request<Body>> for HandleError<S, F, S::Error>
impl<S, F, B, Res> Service<Request<Body>> for HandleError<S, F, S::Error>
where
S: Service<Request<Body>, Response = Response<B>>,
F: FnOnce(S::Error) -> Response<NewBody> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
@ -218,14 +223,13 @@ enum Kind<Fut, E> {
Error(Option<E>),
}
impl<Fut, F, E, B, NewBody> Future for HandleErrorFuture<Fut, F, E>
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F, E>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Response<NewBody>,
F: FnOnce(E) -> Res,
Res: IntoResponse<Body>,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
NewBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
NewBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
@ -237,13 +241,13 @@ where
Ok(res) => Ok(res.map(BoxBody::new)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err);
let res = f(err).into_response();
Ok(res.map(BoxBody::new)).into()
}
},
KindProj::Error(err) => {
let f = this.f.take().unwrap();
let res = f(err.take().unwrap());
let res = f(err.take().unwrap()).into_response();
Ok(res.map(BoxBody::new)).into()
}
}

View file

@ -3,7 +3,10 @@ use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
use serde_json::json;
use std::net::{SocketAddr, TcpListener};
use std::{
net::{SocketAddr, TcpListener},
time::Duration,
};
use tower::{make::Shared, BoxError, Service};
#[tokio::test]
@ -77,9 +80,7 @@ async fn consume_body_to_json_requires_json_content_type() {
let app = app()
.at("/")
.post(|_: Request<Body>, input: extract::Json<Input>| async {
input.0.foo
})
.post(|_: Request<Body>, input: extract::Json<Input>| async { input.0.foo })
.into_service();
let addr = run_in_background(app).await;
@ -274,7 +275,7 @@ async fn boxing() {
#[tokio::test]
async fn service_handlers() {
use crate::{body::BoxBody, ServiceExt as _};
use crate::ServiceExt as _;
use std::convert::Infallible;
use tower::service_fn;
use tower_http::services::ServeFile;
@ -290,13 +291,7 @@ async fn service_handlers() {
.at("/static/Cargo.toml")
.get_service(
ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| {
// `ServeFile` internally maps some errors to `404` so we don't have
// to handle those here
let body = BoxBody::from(error.to_string());
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.unwrap()
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string())
}),
)
// calling boxed isn't necessary here but done so
@ -356,6 +351,36 @@ async fn middleware_on_single_route() {
assert_eq!(body, "Hello, World!");
}
#[tokio::test]
async fn handling_errors_from_layered_single_routes() {
use tower::timeout::TimeoutLayer;
async fn handle(_req: Request<Body>) -> &'static str {
tokio::time::sleep(Duration::from_secs(10)).await;
""
}
let app = app()
.at("/")
.get(
handle
.layer(TimeoutLayer::new(Duration::from_millis(100)))
.handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR),
)
.into_service();
let addr = run_in_background(app).await;
let res = reqwest::get(format!("http://{}", addr)).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
// TODO(david): .layer() on RouteBuilder
// TODO(david): composing two apps
// TODO(david): composing two apps with one at a "sub path"
// TODO(david): composing two boxed apps
// TODO(david): composing two apps that have had layers applied
/// Run a `tower::Service` in the background and get a URI for it.
pub async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
where