diff --git a/CHANGELOG.md b/CHANGELOG.md
index 803e35ac..d236d9a6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adding a conflicting route will now cause a panic instead of silently making
a route unreachable.
- Route matching is faster as number of routes increase.
- - The routes `/foo` and `/:key` are considered to overlap and will cause a
- panic when constructing the router. This might be fixed in the future.
+ - **breaking:** The routes `/foo` and `/:key` are considered to overlap and
+ will cause a panic when constructing the router. This might be fixed in the future.
- Improve performance of `BoxRoute` ([#339])
- Expand accepted content types for JSON requests ([#378])
- **breaking:** Automatically do percent decoding in `extract::Path`
@@ -28,6 +28,73 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
thus wasn't necessary
- **breaking:** Change `Connected::connect_info` to return `Self` and remove
the associated type `ConnectInfo` ([#396])
+- **breaking:** Simplify error handling model ([#402]):
+ - All services part of the router are now required to be infallible.
+ - Error handling utilities have been moved to an `error_handling` module.
+ - `Router::check_infallible` has been removed since routers are always
+ infallible with the error handling changes.
+ - Error handling closures must now handle all errors and thus always return
+ something that implements `IntoResponse`.
+
+ With these changes handling errors from fallible middleware is done like so:
+
+ ```rust,no_run
+ use axum::{
+ handler::get,
+ http::StatusCode,
+ error_handling::HandleErrorLayer,
+ response::IntoResponse,
+ Router, BoxError,
+ };
+ use tower::ServiceBuilder;
+ use std::time::Duration;
+
+ let middleware_stack = ServiceBuilder::new()
+ // Handle errors from middleware
+ //
+ // This middleware most be added above any fallible
+ // ones if you're using `ServiceBuilder`, due to how ordering works
+ .layer(HandleErrorLayer::new(handle_error))
+ // Return an error after 30 seconds
+ .timeout(Duration::from_secs(30));
+
+ let app = Router::new()
+ .route("/", get(|| async { /* ... */ }))
+ .layer(middleware_stack);
+
+ fn handle_error(_error: BoxError) -> impl IntoResponse {
+ StatusCode::REQUEST_TIMEOUT
+ }
+ ```
+
+ And handling errors from fallible leaf services is done like so:
+
+ ```rust
+ use axum::{
+ Router, service,
+ body::Body,
+ handler::get,
+ response::IntoResponse,
+ http::{Request, Response},
+ error_handling::HandleErrorExt, // for `.handle_error`
+ };
+ use std::{io, convert::Infallible};
+ use tower::service_fn;
+
+ let app = Router::new()
+ .route(
+ "/",
+ service::get(service_fn(|_req: Request
| async {
+ let contents = tokio::fs::read_to_string("some_file").await?;
+ Ok::<_, io::Error>(Response::new(Body::from(contents)))
+ }))
+ .handle_error(handle_io_error),
+ );
+
+ fn handle_io_error(error: io::Error) -> impl IntoResponse {
+ // ...
+ }
+ ```
[#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286
@@ -35,6 +102,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#378]: https://github.com/tokio-rs/axum/pull/378
[#363]: https://github.com/tokio-rs/axum/pull/363
[#396]: https://github.com/tokio-rs/axum/pull/396
+[#402]: https://github.com/tokio-rs/axum/pull/402
# 0.2.8 (07. October, 2021)
diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs
index b8153181..c768b441 100644
--- a/examples/key-value-store/src/main.rs
+++ b/examples/key-value-store/src/main.rs
@@ -8,6 +8,7 @@
use axum::{
body::Bytes,
+ error_handling::HandleErrorLayer,
extract::{ContentLengthLimit, Extension, Path},
handler::{delete, get, Handler},
http::StatusCode,
@@ -18,7 +19,6 @@ use axum::{
use std::{
borrow::Cow,
collections::HashMap,
- convert::Infallible,
net::SocketAddr,
sync::{Arc, RwLock},
time::Duration,
@@ -52,16 +52,15 @@ async fn main() {
// Add middleware to all routes
.layer(
ServiceBuilder::new()
+ // Handle errors from middleware
+ .layer(HandleErrorLayer::new(handle_error))
.load_shed()
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(SharedState::default()))
.into_inner(),
- )
- // Handle errors from middleware
- .handle_error(handle_error)
- .check_infallible();
+ );
// Run our app with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@@ -126,20 +125,20 @@ fn admin_routes() -> Router {
.boxed()
}
-fn handle_error(error: BoxError) -> Result {
+fn handle_error(error: BoxError) -> impl IntoResponse {
if error.is::() {
- return Ok((StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")));
+ return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"));
}
if error.is::() {
- return Ok((
+ return (
StatusCode::SERVICE_UNAVAILABLE,
Cow::from("service is overloaded, try again later"),
- ));
+ );
}
- Ok((
+ (
StatusCode::INTERNAL_SERVER_ERROR,
Cow::from(format!("Unhandled internal error: {}", error)),
- ))
+ )
}
diff --git a/examples/print-request-response/src/main.rs b/examples/print-request-response/src/main.rs
index 15fc7f6a..023e3c69 100644
--- a/examples/print-request-response/src/main.rs
+++ b/examples/print-request-response/src/main.rs
@@ -6,12 +6,13 @@
use axum::{
body::{Body, BoxBody, Bytes},
+ error_handling::HandleErrorLayer,
handler::post,
- http::{Request, Response},
+ http::{Request, Response, StatusCode},
Router,
};
use std::net::SocketAddr;
-use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError};
+use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder};
#[tokio::main]
async fn main() {
@@ -26,8 +27,17 @@ async fn main() {
let app = Router::new()
.route("/", post(|| async move { "Hello from `POST /`" }))
- .layer(AsyncFilterLayer::new(map_request))
- .layer(AndThenLayer::new(map_response));
+ .layer(
+ ServiceBuilder::new()
+ .layer(HandleErrorLayer::new(|error| {
+ (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ format!("Unhandled internal error: {}", error),
+ )
+ }))
+ .layer(AndThenLayer::new(map_response))
+ .layer(AsyncFilterLayer::new(map_request)),
+ );
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs
index 53c49feb..406ad4c0 100644
--- a/examples/sse/src/main.rs
+++ b/examples/sse/src/main.rs
@@ -5,6 +5,7 @@
//! ```
use axum::{
+ error_handling::HandleErrorExt,
extract::TypedHeader,
handler::get,
http::StatusCode,
@@ -28,10 +29,10 @@ async fn main() {
ServeDir::new("examples/sse/assets").append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| {
- Ok::<_, std::convert::Infallible>((
+ (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
- ))
+ )
});
// build our application with a route
diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs
index dbdc2904..ec449788 100644
--- a/examples/static-file-server/src/main.rs
+++ b/examples/static-file-server/src/main.rs
@@ -4,8 +4,8 @@
//! cargo run -p example-static-file-server
//! ```
-use axum::{http::StatusCode, service, Router};
-use std::{convert::Infallible, net::SocketAddr};
+use axum::{error_handling::HandleErrorExt, http::StatusCode, service, Router};
+use std::net::SocketAddr;
use tower_http::{services::ServeDir, trace::TraceLayer};
#[tokio::main]
@@ -23,10 +23,10 @@ async fn main() {
.nest(
"/static",
service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| {
- Ok::<_, Infallible>((
+ (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
- ))
+ )
}),
)
.layer(TraceLayer::new_for_http());
diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs
index b5a204ef..5e4ac543 100644
--- a/examples/todos/src/main.rs
+++ b/examples/todos/src/main.rs
@@ -14,6 +14,7 @@
//! ```
use axum::{
+ error_handling::HandleErrorLayer,
extract::{Extension, Path, Query},
handler::{get, patch},
http::StatusCode,
@@ -23,7 +24,6 @@ use axum::{
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
- convert::Infallible,
net::SocketAddr,
sync::{Arc, RwLock},
time::Duration,
@@ -49,25 +49,21 @@ async fn main() {
// Add middleware to all routes
.layer(
ServiceBuilder::new()
+ .layer(HandleErrorLayer::new(|error: BoxError| {
+ if error.is::() {
+ Ok(StatusCode::REQUEST_TIMEOUT)
+ } else {
+ Err((
+ StatusCode::INTERNAL_SERVER_ERROR,
+ format!("Unhandled internal error: {}", error),
+ ))
+ }
+ }))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db))
.into_inner(),
- )
- .handle_error(|error: BoxError| {
- let result = if error.is::() {
- Ok(StatusCode::REQUEST_TIMEOUT)
- } else {
- Err((
- StatusCode::INTERNAL_SERVER_ERROR,
- format!("Unhandled internal error: {}", error),
- ))
- };
-
- Ok::<_, Infallible>(result)
- })
- // Make sure all errors have been handled
- .check_infallible();
+ );
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs
index 2170cd75..fbde0803 100644
--- a/examples/websockets/src/main.rs
+++ b/examples/websockets/src/main.rs
@@ -7,6 +7,7 @@
//! ```
use axum::{
+ error_handling::HandleErrorExt,
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
TypedHeader,
@@ -38,10 +39,10 @@ async fn main() {
ServeDir::new("examples/websockets/assets").append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| {
- Ok::<_, std::convert::Infallible>((
+ (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
- ))
+ )
}),
)
// routes are matched from bottom to top, so we have to put `nest` at the
diff --git a/src/error_handling/mod.rs b/src/error_handling/mod.rs
new file mode 100644
index 00000000..a925bdb6
--- /dev/null
+++ b/src/error_handling/mod.rs
@@ -0,0 +1,199 @@
+//! Error handling utilities
+//!
+//! See [error handling](../index.html#error-handling) for more details on how
+//! error handling works in axum.
+
+use crate::{
+ body::{box_body, BoxBody},
+ response::IntoResponse,
+ BoxError,
+};
+use bytes::Bytes;
+use futures_util::ready;
+use http::{Request, Response};
+use pin_project_lite::pin_project;
+use std::convert::Infallible;
+use std::{
+ fmt,
+ future::Future,
+ marker::PhantomData,
+ pin::Pin,
+ task::{Context, Poll},
+};
+use tower::{util::Oneshot, ServiceExt as _};
+use tower_layer::Layer;
+use tower_service::Service;
+
+/// [`Layer`] that applies [`HandleErrorLayer`] which is a [`Service`] adapter
+/// that handles errors by converting them into responses.
+///
+/// See [error handling](../index.html#error-handling) for more details.
+pub struct HandleErrorLayer {
+ f: F,
+ _marker: PhantomData B>,
+}
+
+impl HandleErrorLayer {
+ /// Create a new `HandleErrorLayer`.
+ pub fn new(f: F) -> Self {
+ Self {
+ f,
+ _marker: PhantomData,
+ }
+ }
+}
+
+impl Layer for HandleErrorLayer
+where
+ F: Clone,
+{
+ type Service = HandleError;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ HandleError {
+ inner,
+ f: self.f.clone(),
+ _marker: PhantomData,
+ }
+ }
+}
+
+impl Clone for HandleErrorLayer
+where
+ F: Clone,
+{
+ fn clone(&self) -> Self {
+ Self {
+ f: self.f.clone(),
+ _marker: PhantomData,
+ }
+ }
+}
+
+impl fmt::Debug for HandleErrorLayer {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("HandleErrorLayer")
+ .field("f", &format_args!("{}", std::any::type_name::()))
+ .finish()
+ }
+}
+
+/// A [`Service`] adapter that handles errors by converting them into responses.
+///
+/// See [error handling](../index.html#error-handling) for more details.
+pub struct HandleError {
+ inner: S,
+ f: F,
+ _marker: PhantomData B>,
+}
+
+impl Clone for HandleError
+where
+ S: Clone,
+ F: Clone,
+{
+ fn clone(&self) -> Self {
+ Self::new(self.inner.clone(), self.f.clone())
+ }
+}
+
+impl HandleError {
+ /// Create a new `HandleError`.
+ pub fn new(inner: S, f: F) -> Self {
+ Self {
+ inner,
+ f,
+ _marker: PhantomData,
+ }
+ }
+}
+
+impl fmt::Debug for HandleError
+where
+ S: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("HandleError")
+ .field("inner", &self.inner)
+ .field("f", &format_args!("{}", std::any::type_name::()))
+ .finish()
+ }
+}
+
+impl Service> for HandleError
+where
+ S: Service, Response = Response> + Clone,
+ F: FnOnce(S::Error) -> Res + Clone,
+ Res: IntoResponse,
+ ResBody: http_body::Body + Send + Sync + 'static,
+ ResBody::Error: Into + Send + Sync + 'static,
+{
+ type Response = Response;
+ type Error = Infallible;
+ type Future = HandleErrorFuture>, F>;
+
+ fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn call(&mut self, req: Request) -> Self::Future {
+ HandleErrorFuture {
+ f: Some(self.f.clone()),
+ inner: self.inner.clone().oneshot(req),
+ }
+ }
+}
+
+/// Handle errors this service might produce, by mapping them to responses.
+///
+/// See [error handling](../index.html#error-handling) for more details.
+pub trait HandleErrorExt: Service> + Sized {
+ /// Apply a [`HandleError`] middleware.
+ fn handle_error(self, f: F) -> HandleError {
+ HandleError::new(self, f)
+ }
+}
+
+impl HandleErrorExt for S where S: Service> {}
+
+pin_project! {
+ /// Response future for [`HandleError`](super::HandleError).
+ #[derive(Debug)]
+ pub struct HandleErrorFuture {
+ #[pin]
+ pub(super) inner: Fut,
+ pub(super) f: Option,
+ }
+}
+
+impl Future for HandleErrorFuture
+where
+ Fut: Future