Move axum-handle-error-extract into axum (#534)

* Move `axum-handle-error-extract` into axum

With 0.4 underway we can now nuke `axum-handle-error-extract` and move
its code directly into axum.

So this replaces the old `HandleErrorLayer` with one that supports async
functions and extractors.

* changelog

* fix CI
This commit is contained in:
David Pedersen 2021-11-17 20:09:58 +01:00 committed by GitHub
parent a317072467
commit 9a410371a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 254 additions and 650 deletions

View file

@ -100,7 +100,6 @@ jobs:
command: test
args: >
-p axum
-p axum-handle-error-extract
--all-features --all-targets
# the compiler errors are different on 1.54 which makes
# the trybuild tests in axum-debug fail, so just run the doc

View file

@ -2,6 +2,5 @@
members = [
"axum",
"axum-debug",
"axum-handle-error-extract",
"examples/*",
]

View file

@ -1,14 +0,0 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
# Unreleased
- None
# 0.1.0 (05. November, 2021)
- Initial release.

View file

@ -1,22 +0,0 @@
[package]
name = "axum-handle-error-extract"
version = "0.1.0"
authors = ["David Pedersen <david.pdrsn@gmail.com>"]
categories = ["asynchronous", "network-programming", "web-programming"]
description = "Error handling layer for axum that supports extractors and async functions"
edition = "2018"
homepage = "https://github.com/tokio-rs/axum"
keywords = ["http", "web", "framework"]
license = "MIT"
readme = "README.md"
repository = "https://github.com/tokio-rs/axum"
[dependencies]
axum = { version = "0.3.2", path = "../axum" }
tower-service = "0.3"
tower-layer = "0.3"
tower = { version = "0.4", features = ["util"] }
pin-project-lite = "0.2"
[dev-dependencies]
tower = { version = "0.4", features = ["util", "timeout"] }

View file

@ -1,25 +0,0 @@
Copyright (c) 2019 Tower Contributors
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View file

@ -1,45 +0,0 @@
# axum-handle-error-extract
Error handling layer for axum that supports extractors and async functions
[![Build status](https://github.com/tokio-rs/axum-handle-error-extract/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-handle-error-extract/actions/workflows/CI.yml)
[![Crates.io](https://img.shields.io/crates/v/axum-handle-error-extract)](https://crates.io/crates/axum-handle-error-extract)
[![Documentation](https://docs.rs/axum-handle-error-extract/badge.svg)](https://docs.rs/axum-handle-error-extract)
More information about this crate can be found in the [crate documentation][docs].
## Safety
This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in
100% safe Rust.
## Minimum supported Rust version
axum-handle-error-extract's MSRV is 1.54.
## Getting Help
You're also welcome to ask in the [Discord channel][chat] or open an [issue]
with your question.
## Contributing
:balloon: Thanks for your help improving the project! We are so happy to have
you! We have a [contributing guide][contributing] to help you get involved in the
`axum` project.
## License
This project is licensed under the [MIT license][license].
### Contribution
Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in `axum` by you, shall be licensed as MIT, without any
additional terms or conditions.
[docs]: https://docs.rs/axum-handle-error-extract
[contributing]: /CONTRIBUTING.md
[chat]: https://discord.gg/tokio
[issue]: https://github.com/tokio-rs/axum/issues/new
[license]: /axum/LICENSE

View file

@ -1,402 +0,0 @@
//! Error handling layer for axum that supports extractors and async functions.
//!
//! This crate provides [`HandleErrorLayer`] which works similarly to
//! [`axum::error_handling::HandleErrorLayer`] except that it supports
//! extractors and async functions:
//!
//! ```rust
//! use axum::{
//! Router,
//! BoxError,
//! response::IntoResponse,
//! http::{StatusCode, Method, Uri},
//! routing::get,
//! };
//! use tower::{ServiceBuilder, timeout::error::Elapsed};
//! use std::time::Duration;
//! use axum_handle_error_extract::HandleErrorLayer;
//!
//! let app = Router::new()
//! .route("/", get(|| async {}))
//! .layer(
//! ServiceBuilder::new()
//! // timeouts produces errors, so we handle those with `handle_error`
//! .layer(HandleErrorLayer::new(handle_error))
//! .timeout(Duration::from_secs(10))
//! );
//!
//! // our handler take can 0 to 16 extractors and the final argument must
//! // always be the error produced by the middleware
//! async fn handle_error(
//! method: Method,
//! uri: Uri,
//! error: BoxError,
//! ) -> impl IntoResponse {
//! if error.is::<Elapsed>() {
//! (
//! StatusCode::REQUEST_TIMEOUT,
//! format!("{} {} took too long", method, uri),
//! )
//! } else {
//! (
//! StatusCode::INTERNAL_SERVER_ERROR,
//! format!("{} {} failed: {}", method, uri, error),
//! )
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! Not running any extractors is also supported:
//!
//! ```rust
//! use axum::{
//! Router,
//! BoxError,
//! response::IntoResponse,
//! http::StatusCode,
//! routing::get,
//! };
//! use tower::{ServiceBuilder, timeout::error::Elapsed};
//! use std::time::Duration;
//! use axum_handle_error_extract::HandleErrorLayer;
//!
//! let app = Router::new()
//! .route("/", get(|| async {}))
//! .layer(
//! ServiceBuilder::new()
//! .layer(HandleErrorLayer::new(handle_error))
//! .timeout(Duration::from_secs(10))
//! );
//!
//! // this function just takes the error
//! async fn handle_error(error: BoxError) -> impl IntoResponse {
//! if error.is::<Elapsed>() {
//! (
//! StatusCode::REQUEST_TIMEOUT,
//! "Request timeout".to_string(),
//! )
//! } else {
//! (
//! StatusCode::INTERNAL_SERVER_ERROR,
//! format!("Unhandled internal error: {}", error),
//! )
//! }
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! See [`axum::error_handling`] for more details on axum's error handling model and
//! [`axum::extract`] for more details on extractors.
//!
//! # The future
//!
//! In axum 0.4 this will replace the current [`axum::error_handling::HandleErrorLayer`].
#![warn(
clippy::all,
clippy::dbg_macro,
clippy::todo,
clippy::empty_enum,
clippy::enum_glob_use,
clippy::mem_forget,
clippy::unused_self,
clippy::filter_map_next,
clippy::needless_continue,
clippy::needless_borrow,
clippy::match_wildcard_for_single_variants,
clippy::if_let_mutex,
clippy::mismatched_target_os,
clippy::await_holding_lock,
clippy::match_on_vec_items,
clippy::imprecise_flops,
clippy::suboptimal_flops,
clippy::lossy_float_literal,
clippy::rest_pat_in_fully_bound_structs,
clippy::fn_params_excessive_bools,
clippy::exit,
clippy::inefficient_to_string,
clippy::linkedlist,
clippy::macro_use_imports,
clippy::option_option,
clippy::verbose_file_reads,
clippy::unnested_or_patterns,
rust_2018_idioms,
future_incompatible,
nonstandard_style,
missing_debug_implementations,
missing_docs
)]
#![deny(unreachable_pub, private_in_public)]
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
#![forbid(unsafe_code)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
use axum::{
body::{boxed, BoxBody, Bytes, Full, HttpBody},
extract::{FromRequest, RequestParts},
http::{Request, Response, StatusCode},
response::IntoResponse,
BoxError,
};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
/// [`Layer`] that applies [`HandleError`] which is a [`Service`] adapter
/// that handles errors by converting them into responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleErrorLayer<F, T> {
f: F,
_extractor: PhantomData<fn() -> T>,
}
impl<F, T> HandleErrorLayer<F, T> {
/// Create a new `HandleErrorLayer`.
pub fn new(f: F) -> Self {
Self {
f,
_extractor: PhantomData,
}
}
}
impl<F, T> Clone for HandleErrorLayer<F, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_extractor: PhantomData,
}
}
}
impl<F, E> fmt::Debug for HandleErrorLayer<F, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleErrorLayer")
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
where
F: Clone,
{
type Service = HandleError<S, F, T>;
fn layer(&self, inner: S) -> Self::Service {
HandleError::new(inner, self.f.clone())
}
}
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, T> {
inner: S,
f: F,
_extractor: PhantomData<fn() -> T>,
}
impl<S, F, T> HandleError<S, F, T> {
/// Create a new `HandleError`.
pub fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
_extractor: PhantomData,
}
}
}
impl<S, F, T> Clone for HandleError<S, F, T>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
f: self.f.clone(),
_extractor: PhantomData,
}
}
}
impl<S, F, E> fmt::Debug for HandleError<S, F, E>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, ReqBody, ResBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Send,
S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let f = self.f.clone();
let clone = self.inner.clone();
let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move {
match inner.oneshot(req).await {
Ok(res) => Ok(res.map(boxed)),
Err(err) => Ok(f(err).await.into_response().map(boxed)),
}
});
ResponseFuture { future }
}
}
#[allow(unused_macros)]
macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => {
impl<S, F, ReqBody, ResBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>>
for HandleError<S, F, ($($ty,)*)>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Send,
S::Future: Send,
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
$( $ty: FromRequest<ReqBody> + Send,)*
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[allow(non_snake_case)]
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let f = self.f.clone();
let clone = self.inner.clone();
let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move {
let mut req = RequestParts::new(req);
$(
let $ty = match $ty::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response().map(boxed)),
};
)*
let req = match req.try_into_request() {
Ok(req) => req,
Err(err) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from(err.to_string())))
.unwrap());
}
};
match inner.oneshot(req).await {
Ok(res) => Ok(res.map(boxed)),
Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)),
}
});
ResponseFuture { future }
}
}
}
}
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
pin_project! {
/// Response future for [`HandleError`].
pub struct ResponseFuture {
#[pin]
future: Pin<Box<dyn Future<Output = Result<Response<BoxBody>, Infallible>> + Send + 'static>>,
}
}
impl Future for ResponseFuture {
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().future.poll(cx)
}
}
/// Extension trait to [`Service`] for handling errors by mapping them to
/// responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub trait HandleErrorExt<B>: Service<Request<B>> + Sized {
/// Apply a [`HandleError`] middleware.
fn handle_error<F>(self, f: F) -> HandleError<Self, F, B> {
HandleError::new(self, f)
}
}
impl<B, S> HandleErrorExt<B> for S where S: Service<Request<B>> {}

View file

@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`routing::{get, get_service, ..., MethodRouter}`.
- **breaking:** `HandleErrorExt` has been removed in favor of
`MethodRouter::handle_error`.
- **breaking:** `HandleErrorLayer` now requires the handler function to be
`async` ([#534])
- **added:** `HandleErrorLayer` now supports running extractors.
- **breaking:** The `Handler<B, T>` trait is now defined as `Handler<T, B =
Body>`. That is the type parameters have been swapped and `B` defaults to
`axum::body::Body` ([#527])
@ -29,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#525]: https://github.com/tokio-rs/axum/pull/525
[#527]: https://github.com/tokio-rs/axum/pull/527
[#534]: https://github.com/tokio-rs/axum/pull/534
# 0.3.3 (13. November, 2021)

View file

@ -72,7 +72,7 @@ let app = Router::new().route(
// handle errors by converting them into something that implements
// `IntoResponse`
fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", err),
@ -109,7 +109,7 @@ let app = Router::new()
.timeout(Duration::from_secs(30))
);
fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
async fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
if err.is::<tower::timeout::error::Elapsed>() {
(
StatusCode::REQUEST_TIMEOUT,
@ -127,6 +127,48 @@ fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
# };
```
# Running extractors for error handling
`HandleErrorLayer` also supports running extractors:
```rust
use axum::{
Router,
BoxError,
routing::get,
http::{StatusCode, Method, Uri},
error_handling::HandleErrorLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;
let app = Router::new()
.route("/", get(|| async {}))
.layer(
ServiceBuilder::new()
// `timeout` will produce an error if the handler takes
// too long so we must handle those
.layer(HandleErrorLayer::new(handle_timeout_error))
.timeout(Duration::from_secs(30))
);
async fn handle_timeout_error(
// `Method` and `Uri` are extractors so they can be used here
method: Method,
uri: Uri,
// the last argument must be the error itself
err: BoxError,
) -> (StatusCode, String) {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("`{} {}` failed with {}", method, uri, err),
)
}
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
# };
```
[`tower::Service`]: `tower::Service`
[`Infallible`]: std::convert::Infallible
[`Response<_>`]: http::Response

View file

@ -69,7 +69,7 @@ use tower::{
use std::convert::Infallible;
use tower_http::trace::TraceLayer;
#
# fn handle_error<T>(error: T) -> axum::http::StatusCode {
# async fn handle_error<T>(error: T) -> axum::http::StatusCode {
# axum::http::StatusCode::INTERNAL_SERVER_ERROR
# }

View file

@ -124,7 +124,9 @@ let app = Router::new()
ServiceBuilder::new()
// this middleware goes above `TimeoutLayer` because it will receive
// errors returned by `TimeoutLayer`
.layer(HandleErrorLayer::new(|_: BoxError| StatusCode::REQUEST_TIMEOUT))
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(TimeoutLayer::new(Duration::from_secs(10)))
);
# async {

View file

@ -74,13 +74,14 @@ use axum::{
Router,
routing::get_service,
http::StatusCode,
error_handling::HandleErrorLayer,
};
use std::{io, convert::Infallible};
use tower_http::services::ServeDir;
// Serves files inside the `public` directory at `GET /public/*`
let serve_dir_service = get_service(ServeDir::new("public"))
.handle_error(|error: io::Error| {
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -113,6 +113,7 @@ use axum::{
body::Body,
routing::{any_service, get_service},
http::{Request, StatusCode},
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
@ -147,7 +148,7 @@ let app = Router::new()
"/static/Cargo.toml",
get_service(ServeFile::new("Cargo.toml"))
// though we must handle any potential errors
.handle_error(|error: io::Error| {
.handle_error(|error: io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -1,15 +1,20 @@
#![doc = include_str!("../docs/error_handling.md")]
use crate::{body::BoxBody, response::IntoResponse, BoxError};
use bytes::Bytes;
use http::{Request, Response};
use crate::{
body::{boxed, BoxBody, Bytes, Full, HttpBody},
extract::{FromRequest, RequestParts},
http::{Request, Response, StatusCode},
response::IntoResponse,
BoxError,
};
use std::{
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
task::{Context, Poll},
};
use tower::{util::Oneshot, ServiceExt as _};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
@ -17,49 +22,34 @@ use tower_service::Service;
/// that handles errors by converting them into responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleErrorLayer<F, B> {
pub struct HandleErrorLayer<F, T> {
f: F,
_marker: PhantomData<fn() -> B>,
_extractor: PhantomData<fn() -> T>,
}
impl<F, B> HandleErrorLayer<F, B> {
impl<F, T> HandleErrorLayer<F, T> {
/// Create a new `HandleErrorLayer`.
pub fn new(f: F) -> Self {
Self {
f,
_marker: PhantomData,
_extractor: PhantomData,
}
}
}
impl<F, B, S> Layer<S> for HandleErrorLayer<F, B>
where
F: Clone,
{
type Service = HandleError<S, F, B>;
fn layer(&self, inner: S) -> Self::Service {
HandleError {
inner,
f: self.f.clone(),
_marker: PhantomData,
}
}
}
impl<F, B> Clone for HandleErrorLayer<F, B>
impl<F, T> Clone for HandleErrorLayer<F, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_marker: PhantomData,
_extractor: PhantomData,
}
}
}
impl<F, B> fmt::Debug for HandleErrorLayer<F, B> {
impl<F, E> fmt::Debug for HandleErrorLayer<F, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleErrorLayer")
.field("f", &format_args!("{}", std::any::type_name::<F>()))
@ -67,37 +57,52 @@ impl<F, B> fmt::Debug for HandleErrorLayer<F, B> {
}
}
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, B> {
inner: S,
f: F,
_marker: PhantomData<fn() -> B>,
}
impl<S, F, B> Clone for HandleError<S, F, B>
impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self::new(self.inner.clone(), self.f.clone())
type Service = HandleError<S, F, T>;
fn layer(&self, inner: S) -> Self::Service {
HandleError::new(inner, self.f.clone())
}
}
impl<S, F, B> HandleError<S, F, B> {
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// See [module docs](self) for more details on axum's error handling model.
pub struct HandleError<S, F, T> {
inner: S,
f: F,
_extractor: PhantomData<fn() -> T>,
}
impl<S, F, T> HandleError<S, F, T> {
/// Create a new `HandleError`.
pub fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
_marker: PhantomData,
_extractor: PhantomData,
}
}
}
impl<S, F, B> fmt::Debug for HandleError<S, F, B>
impl<S, F, T> Clone for HandleError<S, F, T>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
f: self.f.clone(),
_extractor: PhantomData,
}
}
}
impl<S, F, E> fmt::Debug for HandleError<S, F, E>
where
S: fmt::Debug,
{
@ -109,41 +114,114 @@ where
}
}
impl<S, F, ReqBody, ResBody, Res> Service<Request<ReqBody>> for HandleError<S, F, ReqBody>
impl<S, F, ReqBody, ResBody, Fut, Res> Service<Request<ReqBody>> for HandleError<S, F, ()>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Send,
S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = future::HandleErrorFuture<Oneshot<S, Request<ReqBody>>, F>;
type Future = future::HandleErrorFuture;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
future::HandleErrorFuture {
f: Some(self.f.clone()),
inner: self.inner.clone().oneshot(req),
let f = self.f.clone();
let clone = self.inner.clone();
let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move {
match inner.oneshot(req).await {
Ok(res) => Ok(res.map(boxed)),
Err(err) => Ok(f(err).await.into_response().map(boxed)),
}
});
future::HandleErrorFuture { future }
}
}
#[allow(unused_macros)]
macro_rules! impl_service {
( $($ty:ident),* $(,)? ) => {
impl<S, F, ReqBody, ResBody, Res, Fut, $($ty,)*> Service<Request<ReqBody>>
for HandleError<S, F, ($($ty,)*)>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
S::Error: Send,
S::Future: Send,
F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
$( $ty: FromRequest<ReqBody> + Send,)*
ReqBody: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = future::HandleErrorFuture;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[allow(non_snake_case)]
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let f = self.f.clone();
let clone = self.inner.clone();
let inner = std::mem::replace(&mut self.inner, clone);
let future = Box::pin(async move {
let mut req = RequestParts::new(req);
$(
let $ty = match $ty::from_request(&mut req).await {
Ok(value) => value,
Err(rejection) => return Ok(rejection.into_response().map(boxed)),
};
)*
let req = match req.try_into_request() {
Ok(req) => req,
Err(err) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(boxed(Full::from(err.to_string())))
.unwrap());
}
};
match inner.oneshot(req).await {
Ok(res) => Ok(res.map(boxed)),
Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)),
}
});
future::HandleErrorFuture { future }
}
}
}
}
all_the_tuples!(impl_service);
pub mod future {
//! Future types.
use crate::{
body::{boxed, BoxBody},
response::IntoResponse,
BoxError,
};
use bytes::Bytes;
use futures_util::ready;
use crate::body::BoxBody;
use http::Response;
use pin_project_lite::pin_project;
use std::{
@ -154,36 +232,21 @@ pub mod future {
};
pin_project! {
/// Response future for [`HandleError`](super::HandleError).
#[derive(Debug)]
pub struct HandleErrorFuture<Fut, F> {
/// Response future for [`HandleError`].
pub struct HandleErrorFuture {
#[pin]
pub(super) inner: Fut,
pub(super) f: Option<F>,
pub(super) future: Pin<Box<dyn Future<Output = Result<Response<BoxBody>, Infallible>>
+ Send
+ 'static
>>,
}
}
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
Res: IntoResponse,
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
impl Future for HandleErrorFuture {
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(boxed)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err);
Ok(res.into_response().map(boxed)).into()
}
}
self.project().future.poll(cx)
}
}
}

View file

@ -40,19 +40,4 @@ macro_rules! impl_from_request {
};
}
impl_from_request!(T1);
impl_from_request!(T1, T2);
impl_from_request!(T1, T2, T3);
impl_from_request!(T1, T2, T3, T4);
impl_from_request!(T1, T2, T3, T4, T5);
impl_from_request!(T1, T2, T3, T4, T5, T6);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
all_the_tuples!(impl_from_request);

View file

@ -308,22 +308,7 @@ macro_rules! impl_handler {
};
}
impl_handler!(T1);
impl_handler!(T1, T2);
impl_handler!(T1, T2, T3);
impl_handler!(T1, T2, T3, T4);
impl_handler!(T1, T2, T3, T4, T5);
impl_handler!(T1, T2, T3, T4, T5, T6);
impl_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
all_the_tuples!(impl_handler);
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///

View file

@ -184,3 +184,24 @@ macro_rules! composite_rejection {
}
};
}
macro_rules! all_the_tuples {
($name:ident) => {
$name!(T1);
$name!(T1, T2);
$name!(T1, T2, T3);
$name!(T1, T2, T3, T4);
$name!(T1, T2, T3, T4, T5);
$name!(T1, T2, T3, T4, T5, T6);
$name!(T1, T2, T3, T4, T5, T6, T7);
$name!(T1, T2, T3, T4, T5, T6, T7, T8);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
};
}

View file

@ -1,6 +1,6 @@
use crate::{
body::{boxed, Body, BoxBody, Bytes},
error_handling::HandleErrorLayer,
error_handling::{HandleError, HandleErrorLayer},
handler::Handler,
http::{Method, Request, Response, StatusCode},
routing::{Fallback, MethodFilter, Route},
@ -795,12 +795,15 @@ impl<ReqBody, E> MethodRouter<ReqBody, E> {
/// Apply a [`HandleErrorLayer`].
///
/// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
pub fn handle_error<F, Res>(self, f: F) -> MethodRouter<ReqBody, Infallible>
pub fn handle_error<F, T>(self, f: F) -> MethodRouter<ReqBody, Infallible>
where
F: FnOnce(E) -> Res + Clone + Send + 'static,
Res: crate::response::IntoResponse,
ReqBody: Send + 'static,
F: Clone + Send + 'static,
HandleError<Route<ReqBody, E>, F, T>:
Service<Request<ReqBody>, Response = Response<BoxBody>, Error = Infallible>,
<HandleError<Route<ReqBody, E>, F, T> as Service<Request<ReqBody>>>::Future: Send,
T: 'static,
E: 'static,
ReqBody: 'static,
{
self.layer(HandleErrorLayer::new(f))
}
@ -1047,12 +1050,17 @@ mod tests {
get(ok)
.post(ok)
.route_layer(RequireAuthorizationLayer::bearer("password"))
.merge(delete_service(ServeDir::new(".")).handle_error(|_| StatusCode::NOT_FOUND))
.merge(
delete_service(ServeDir::new("."))
.handle_error(|_| async { StatusCode::NOT_FOUND }),
)
.fallback((|| async { StatusCode::NOT_FOUND }).into_service())
.put(ok)
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT))
.layer(HandleErrorLayer::new(|_| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(TimeoutLayer::new(Duration::from_secs(10))),
),
);

View file

@ -35,7 +35,7 @@ async fn handler() {
"/",
get(forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
@ -54,7 +54,7 @@ async fn handler_multiple_methods_first() {
"/",
get(forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
@ -76,7 +76,7 @@ async fn handler_multiple_methods_middle() {
.get(
forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
@ -98,7 +98,7 @@ async fn handler_multiple_methods_last() {
delete(unit).get(
forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),

View file

@ -137,7 +137,9 @@ async fn layer_and_handle_error() {
.route("/timeout", get(futures::future::pending::<()>))
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT))
.layer(HandleErrorLayer::new(|_| async {
StatusCode::REQUEST_TIMEOUT
}))
.layer(TimeoutLayer::new(Duration::from_millis(10))),
);
let app = one.merge(two);

View file

@ -299,7 +299,7 @@ async fn middleware_applies_to_routes_above() {
.route("/one", get(std::future::pending::<()>))
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
.layer(HandleErrorLayer::new(|_: BoxError| async move {
StatusCode::REQUEST_TIMEOUT
}))
.layer(TimeoutLayer::new(Duration::new(0, 0))),

View file

@ -169,7 +169,7 @@ async fn nested_service_sees_stripped_uri() {
async fn nest_static_file_server() {
let app = Router::new().nest(
"/static",
get_service(ServeDir::new(".")).handle_error(|error| {
get_service(ServeDir::new(".")).handle_error(|error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -124,7 +124,7 @@ fn admin_routes() -> Router {
.layer(RequireAuthorizationLayer::bearer("secret-token"))
}
fn handle_error(error: BoxError) -> impl IntoResponse {
async fn handle_error(error: BoxError) -> impl IntoResponse {
if error.is::<tower::timeout::error::Elapsed>() {
return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"));
}

View file

@ -29,7 +29,7 @@ async fn main() {
.route("/", post(|| async move { "Hello from `POST /`" }))
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error| {
.layer(HandleErrorLayer::new(|error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -26,7 +26,7 @@ async fn main() {
let static_files_service =
get_service(ServeDir::new("examples/sse/assets").append_index_html_on_directories(true))
.handle_error(|error: std::io::Error| {
.handle_error(|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -22,7 +22,7 @@ async fn main() {
let app = Router::new()
.nest(
"/static",
get_service(ServeDir::new(".")).handle_error(|error: std::io::Error| {
get_service(ServeDir::new(".")).handle_error(|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),

View file

@ -49,7 +49,7 @@ async fn main() {
// Add middleware to all routes
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error: BoxError| {
.layer(HandleErrorLayer::new(|error: BoxError| async move {
if error.is::<tower::timeout::error::Elapsed>() {
Ok(StatusCode::REQUEST_TIMEOUT)
} else {

View file

@ -36,7 +36,7 @@ async fn main() {
get_service(
ServeDir::new("examples/websockets/assets").append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| {
.handle_error(|error: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),