Revamp error handling model (#402)

* Revamp error handling model

* changelog improvements and typo fixes

* Fix a few more Infallible bounds

* minor docs fixes
This commit is contained in:
David Pedersen 2021-10-24 19:33:03 +02:00 committed by GitHub
parent 1a78a3f224
commit f10508db0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 501 additions and 558 deletions

View file

@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Adding a conflicting route will now cause a panic instead of silently making
a route unreachable.
- Route matching is faster as number of routes increase.
- The routes `/foo` and `/:key` are considered to overlap and will cause a
panic when constructing the router. This might be fixed in the future.
- **breaking:** The routes `/foo` and `/:key` are considered to overlap and
will cause a panic when constructing the router. This might be fixed in the future.
- Improve performance of `BoxRoute` ([#339])
- Expand accepted content types for JSON requests ([#378])
- **breaking:** Automatically do percent decoding in `extract::Path`
@ -28,6 +28,73 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
thus wasn't necessary
- **breaking:** Change `Connected::connect_info` to return `Self` and remove
the associated type `ConnectInfo` ([#396])
- **breaking:** Simplify error handling model ([#402]):
- All services part of the router are now required to be infallible.
- Error handling utilities have been moved to an `error_handling` module.
- `Router::check_infallible` has been removed since routers are always
infallible with the error handling changes.
- Error handling closures must now handle all errors and thus always return
something that implements `IntoResponse`.
With these changes handling errors from fallible middleware is done like so:
```rust,no_run
use axum::{
handler::get,
http::StatusCode,
error_handling::HandleErrorLayer,
response::IntoResponse,
Router, BoxError,
};
use tower::ServiceBuilder;
use std::time::Duration;
let middleware_stack = ServiceBuilder::new()
// Handle errors from middleware
//
// This middleware most be added above any fallible
// ones if you're using `ServiceBuilder`, due to how ordering works
.layer(HandleErrorLayer::new(handle_error))
// Return an error after 30 seconds
.timeout(Duration::from_secs(30));
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(middleware_stack);
fn handle_error(_error: BoxError) -> impl IntoResponse {
StatusCode::REQUEST_TIMEOUT
}
```
And handling errors from fallible leaf services is done like so:
```rust
use axum::{
Router, service,
body::Body,
handler::get,
response::IntoResponse,
http::{Request, Response},
error_handling::HandleErrorExt, // for `.handle_error`
};
use std::{io, convert::Infallible};
use tower::service_fn;
let app = Router::new()
.route(
"/",
service::get(service_fn(|_req: Request<Body>| async {
let contents = tokio::fs::read_to_string("some_file").await?;
Ok::<_, io::Error>(Response::new(Body::from(contents)))
}))
.handle_error(handle_io_error),
);
fn handle_io_error(error: io::Error) -> impl IntoResponse {
// ...
}
```
[#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286
@ -35,6 +102,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#378]: https://github.com/tokio-rs/axum/pull/378
[#363]: https://github.com/tokio-rs/axum/pull/363
[#396]: https://github.com/tokio-rs/axum/pull/396
[#402]: https://github.com/tokio-rs/axum/pull/402
# 0.2.8 (07. October, 2021)

View file

@ -8,6 +8,7 @@
use axum::{
body::Bytes,
error_handling::HandleErrorLayer,
extract::{ContentLengthLimit, Extension, Path},
handler::{delete, get, Handler},
http::StatusCode,
@ -18,7 +19,6 @@ use axum::{
use std::{
borrow::Cow,
collections::HashMap,
convert::Infallible,
net::SocketAddr,
sync::{Arc, RwLock},
time::Duration,
@ -52,16 +52,15 @@ async fn main() {
// Add middleware to all routes
.layer(
ServiceBuilder::new()
// Handle errors from middleware
.layer(HandleErrorLayer::new(handle_error))
.load_shed()
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(SharedState::default()))
.into_inner(),
)
// Handle errors from middleware
.handle_error(handle_error)
.check_infallible();
);
// Run our app with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -126,20 +125,20 @@ fn admin_routes() -> Router<BoxRoute> {
.boxed()
}
fn handle_error(error: BoxError) -> Result<impl IntoResponse, Infallible> {
fn handle_error(error: BoxError) -> impl IntoResponse {
if error.is::<tower::timeout::error::Elapsed>() {
return Ok((StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")));
return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"));
}
if error.is::<tower::load_shed::error::Overloaded>() {
return Ok((
return (
StatusCode::SERVICE_UNAVAILABLE,
Cow::from("service is overloaded, try again later"),
));
);
}
Ok((
(
StatusCode::INTERNAL_SERVER_ERROR,
Cow::from(format!("Unhandled internal error: {}", error)),
))
)
}

View file

@ -6,12 +6,13 @@
use axum::{
body::{Body, BoxBody, Bytes},
error_handling::HandleErrorLayer,
handler::post,
http::{Request, Response},
http::{Request, Response, StatusCode},
Router,
};
use std::net::SocketAddr;
use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError};
use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder};
#[tokio::main]
async fn main() {
@ -26,8 +27,17 @@ async fn main() {
let app = Router::new()
.route("/", post(|| async move { "Hello from `POST /`" }))
.layer(AsyncFilterLayer::new(map_request))
.layer(AndThenLayer::new(map_response));
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
}))
.layer(AndThenLayer::new(map_response))
.layer(AsyncFilterLayer::new(map_request)),
);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -5,6 +5,7 @@
//! ```
use axum::{
error_handling::HandleErrorExt,
extract::TypedHeader,
handler::get,
http::StatusCode,
@ -28,10 +29,10 @@ async fn main() {
ServeDir::new("examples/sse/assets").append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| {
Ok::<_, std::convert::Infallible>((
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
)
});
// build our application with a route

View file

@ -4,8 +4,8 @@
//! cargo run -p example-static-file-server
//! ```
use axum::{http::StatusCode, service, Router};
use std::{convert::Infallible, net::SocketAddr};
use axum::{error_handling::HandleErrorExt, http::StatusCode, service, Router};
use std::net::SocketAddr;
use tower_http::{services::ServeDir, trace::TraceLayer};
#[tokio::main]
@ -23,10 +23,10 @@ async fn main() {
.nest(
"/static",
service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| {
Ok::<_, Infallible>((
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
)
}),
)
.layer(TraceLayer::new_for_http());

View file

@ -14,6 +14,7 @@
//! ```
use axum::{
error_handling::HandleErrorLayer,
extract::{Extension, Path, Query},
handler::{get, patch},
http::StatusCode,
@ -23,7 +24,6 @@ use axum::{
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
net::SocketAddr,
sync::{Arc, RwLock},
time::Duration,
@ -49,25 +49,21 @@ async fn main() {
// Add middleware to all routes
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error: BoxError| {
if error.is::<tower::timeout::error::Elapsed>() {
Ok(StatusCode::REQUEST_TIMEOUT)
} else {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
}
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db))
.into_inner(),
)
.handle_error(|error: BoxError| {
let result = if error.is::<tower::timeout::error::Elapsed>() {
Ok(StatusCode::REQUEST_TIMEOUT)
} else {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
};
Ok::<_, Infallible>(result)
})
// Make sure all errors have been handled
.check_infallible();
);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -7,6 +7,7 @@
//! ```
use axum::{
error_handling::HandleErrorExt,
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
TypedHeader,
@ -38,10 +39,10 @@ async fn main() {
ServeDir::new("examples/websockets/assets").append_index_html_on_directories(true),
)
.handle_error(|error: std::io::Error| {
Ok::<_, std::convert::Infallible>((
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
)
}),
)
// routes are matched from bottom to top, so we have to put `nest` at the

199
src/error_handling/mod.rs Normal file
View file

@ -0,0 +1,199 @@
//! Error handling utilities
//!
//! See [error handling](../index.html#error-handling) for more details on how
//! error handling works in axum.
use crate::{
body::{box_body, BoxBody},
response::IntoResponse,
BoxError,
};
use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use pin_project_lite::pin_project;
use std::convert::Infallible;
use std::{
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tower::{util::Oneshot, ServiceExt as _};
use tower_layer::Layer;
use tower_service::Service;
/// [`Layer`] that applies [`HandleErrorLayer`] which is a [`Service`] adapter
/// that handles errors by converting them into responses.
///
/// See [error handling](../index.html#error-handling) for more details.
pub struct HandleErrorLayer<F, B> {
f: F,
_marker: PhantomData<fn() -> B>,
}
impl<F, B> HandleErrorLayer<F, B> {
/// Create a new `HandleErrorLayer`.
pub fn new(f: F) -> Self {
Self {
f,
_marker: PhantomData,
}
}
}
impl<F, B, S> Layer<S> for HandleErrorLayer<F, B>
where
F: Clone,
{
type Service = HandleError<S, F, B>;
fn layer(&self, inner: S) -> Self::Service {
HandleError {
inner,
f: self.f.clone(),
_marker: PhantomData,
}
}
}
impl<F, B> Clone for HandleErrorLayer<F, B>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
_marker: PhantomData,
}
}
}
impl<F, B> fmt::Debug for HandleErrorLayer<F, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleErrorLayer")
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
/// A [`Service`] adapter that handles errors by converting them into responses.
///
/// See [error handling](../index.html#error-handling) for more details.
pub struct HandleError<S, F, B> {
inner: S,
f: F,
_marker: PhantomData<fn() -> B>,
}
impl<S, F, B> Clone for HandleError<S, F, B>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self::new(self.inner.clone(), self.f.clone())
}
}
impl<S, F, B> HandleError<S, F, B> {
/// Create a new `HandleError`.
pub fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
_marker: PhantomData,
}
}
}
impl<S, F, B> fmt::Debug for HandleError<S, F, B>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, ReqBody, ResBody, Res> Service<Request<ReqBody>> for HandleError<S, F, ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Res + Clone,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = HandleErrorFuture<Oneshot<S, Request<ReqBody>>, F>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
HandleErrorFuture {
f: Some(self.f.clone()),
inner: self.inner.clone().oneshot(req),
}
}
}
/// Handle errors this service might produce, by mapping them to responses.
///
/// See [error handling](../index.html#error-handling) for more details.
pub trait HandleErrorExt<B>: Service<Request<B>> + Sized {
/// Apply a [`HandleError`] middleware.
fn handle_error<F>(self, f: F) -> HandleError<Self, F, B> {
HandleError::new(self, f)
}
}
impl<B, S> HandleErrorExt<B> for S where S: Service<Request<B>> {}
pin_project! {
/// Response future for [`HandleError`](super::HandleError).
#[derive(Debug)]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
pub(super) inner: Fut,
pub(super) f: Option<F>,
}
}
impl<Fut, F, E, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Res,
Res: IntoResponse,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(box_body)).into(),
Err(err) => {
let f = this.f.take().unwrap();
let res = f(err);
Ok(res.into_response().map(box_body)).into()
}
}
}
}
#[test]
fn traits() {
use crate::tests::*;
assert_send::<HandleError<(), (), NotSendSync>>();
assert_sync::<HandleError<(), (), NotSendSync>>();
}

View file

@ -5,7 +5,6 @@ use crate::{
extract::{FromRequest, RequestParts},
response::IntoResponse,
routing::{EmptyRouter, MethodFilter},
service::HandleError,
util::Either,
BoxError,
};
@ -264,9 +263,6 @@ pub trait Handler<B, T>: Clone + Send + Sized + 'static {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// When adding middleware that might fail its recommended to handle those
/// errors. See [`Layered::handle_error`] for more details.
fn layer<L>(self, layer: L) -> Layered<L::Service, T>
where
L: Layer<OnMethod<Self, B, T, EmptyRouter>>,
@ -426,28 +422,6 @@ impl<S, T> Layered<S, T> {
_input: PhantomData,
}
}
/// Create a new [`Layered`] handler where errors will be handled using the
/// given closure.
///
/// This is used to convert errors to responses rather than simply
/// terminating the connection.
///
/// It works similarly to [`routing::Router::handle_error`]. See that for more details.
///
/// [`routing::Router::handle_error`]: crate::routing::Router::handle_error
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self,
f: F,
) -> Layered<HandleError<S, F, ReqBody>, T>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
F: FnOnce(S::Error) -> Result<Res, E>,
Res: IntoResponse,
{
let svc = HandleError::new(self.svc, f);
Layered::new(svc)
}
}
/// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and

View file

@ -7,7 +7,6 @@
//! - [Handlers](#handlers)
//! - [Debugging handler type errors](#debugging-handler-type-errors)
//! - [Routing](#routing)
//! - [Matching multiple methods](#matching-multiple-methods)
//! - [Routing to any `Service`](#routing-to-any-service)
//! - [Routing to fallible services](#routing-to-fallible-services)
//! - [Wildcard routes](#wildcard-routes)
@ -18,10 +17,10 @@
//! - [Optional extractors](#optional-extractors)
//! - [Customizing extractor responses](#customizing-extractor-responses)
//! - [Building responses](#building-responses)
//! - [Error handling](#error-handling)
//! - [Applying middleware](#applying-middleware)
//! - [To individual handlers](#to-individual-handlers)
//! - [To groups of routes](#to-groups-of-routes)
//! - [Error handling](#error-handling)
//! - [Applying multiple middleware](#applying-multiple-middleware)
//! - [Commonly used middleware](#commonly-used-middleware)
//! - [Writing your own middleware](#writing-your-own-middleware)
@ -186,14 +185,15 @@
//!
//! ```rust,no_run
//! use axum::{
//! body::Body,
//! http::Request,
//! Router,
//! service
//! service,
//! body::Body,
//! error_handling::HandleErrorExt,
//! http::{Request, StatusCode},
//! };
//! use tower_http::services::ServeFile;
//! use http::Response;
//! use std::convert::Infallible;
//! use std::{convert::Infallible, io};
//! use tower::service_fn;
//!
//! let app = Router::new()
@ -205,7 +205,7 @@
//! // to have the response body mapped
//! service::any(service_fn(|_: Request<Body>| async {
//! let res = Response::new(Body::from("Hi from `GET /`"));
//! Ok(res)
//! Ok::<_, Infallible>(res)
//! }))
//! )
//! .route(
@ -216,13 +216,20 @@
//! let body = Body::from(format!("Hi from `{} /foo`", req.method()));
//! let body = axum::body::box_body(body);
//! let res = Response::new(body);
//! Ok(res)
//! Ok::<_, Infallible>(res)
//! })
//! )
//! .route(
//! // GET `/static/Cargo.toml` goes to a service from tower-http
//! "/static/Cargo.toml",
//! service::get(ServeFile::new("Cargo.toml"))
//! // though we must handle any potential errors
//! .handle_error(|error: io::Error| {
//! (
//! StatusCode::INTERNAL_SERVER_ERROR,
//! format!("Unhandled internal error: {}", error),
//! )
//! })
//! );
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
@ -270,12 +277,12 @@
//!
//! ```
//! use axum::{
//! Router,
//! service,
//! handler::get,
//! http::{Request, Response},
//! response::IntoResponse,
//! Router, service,
//! body::Body,
//! handler::get,
//! response::IntoResponse,
//! http::{Request, Response},
//! error_handling::HandleErrorExt,
//! };
//! use std::{io, convert::Infallible};
//! use tower::service_fn;
@ -293,8 +300,7 @@
//! .handle_error(handle_io_error),
//! );
//!
//! fn handle_io_error(error: io::Error) -> Result<impl IntoResponse, Infallible> {
//! # Ok(())
//! fn handle_io_error(error: io::Error) -> impl IntoResponse {
//! // ...
//! }
//! # async {
@ -664,7 +670,7 @@
//! 1. Use `Result<T, T::Rejection>` as your extractor like shown in ["Optional
//! extractors"](#optional-extractors). This works well if you're only using
//! the extractor in a single handler.
//! 2. Create your own extractor that in its [`FromRequest`] implementing calls
//! 2. Create your own extractor that in its [`FromRequest`] implemention calls
//! one of axum's built in extractors but returns a different response for
//! rejections. See the [customize-extractor-error] example for more details.
//!
@ -781,6 +787,26 @@
//! # };
//! ```
//!
//! # Error handling
//!
//! In the context of axum an "error" specifically means if a [`Service`]'s
//! response future resolves to `Err(Service::Error)`. That means async handler
//! functions can _never_ fail since they always produce a response and their
//! `Service::Error` type is [`Infallible`]. Returning statuses like 404 or 500
//! are _not_ errors.
//!
//! axum works this way because hyper will close the connection, without sending
//! a response, if an error is encountered. This is not desireable so axum makes
//! it impossible to forget to handle errors.
//!
//! Sometimes you need to route to fallible services or apply fallible
//! middleware in which case you need to handle the errors. That can be done
//! using things from [`error_handling`].
//!
//! You can find examples here:
//! - [Routing to fallible services](#routing-to-fallible-services)
//! - [Applying fallible middleware](#applying-multiple-middleware)
//!
//! # Applying middleware
//!
//! axum is designed to take full advantage of the tower and tower-http
@ -864,69 +890,6 @@
//! # };
//! ```
//!
//! ## Error handling
//!
//! Handlers created from async functions must always produce a response, even
//! when returning a `Result<T, E>` the error type must implement
//! [`IntoResponse`]. In practice this makes error handling very predictable and
//! easier to reason about.
//!
//! However when applying middleware, or embedding other tower services, errors
//! might happen. For example [`Timeout`] will return an error if the timeout
//! elapses. By default these errors will be propagated all the way up to hyper
//! where the connection will be closed. If that isn't desirable you can call
//! [`handle_error`](handler::Layered::handle_error) to handle errors from
//! adding a middleware to a handler:
//!
//! ```rust,no_run
//! use axum::{
//! handler::{get, Handler},
//! Router,
//! };
//! use tower::{
//! BoxError, timeout::{TimeoutLayer, error::Elapsed},
//! };
//! use std::{borrow::Cow, time::Duration, convert::Infallible};
//! use http::StatusCode;
//!
//! let app = Router::new()
//! .route(
//! "/",
//! get(handle
//! .layer(TimeoutLayer::new(Duration::from_secs(30)))
//! // `Timeout` uses `BoxError` as the error type
//! .handle_error(|error: BoxError| {
//! // Check if the actual error type is `Elapsed` which
//! // `Timeout` returns
//! if error.is::<Elapsed>() {
//! return Ok::<_, Infallible>((
//! StatusCode::REQUEST_TIMEOUT,
//! "Request took too long".into(),
//! ));
//! }
//!
//! // If we encounter some error we don't handle return a generic
//! // error
//! return Ok::<_, Infallible>((
//! StatusCode::INTERNAL_SERVER_ERROR,
//! // `Cow` lets us return either `&str` or `String`
//! Cow::from(format!("Unhandled internal error: {}", error)),
//! ));
//! })),
//! );
//!
//! async fn handle() {}
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! The closure passed to [`handle_error`](handler::Layered::handle_error) must
//! return `Result<T, E>` where `T` implements
//! [`IntoResponse`](response::IntoResponse).
//!
//! See [`routing::Router::handle_error`] for more details.
//!
//! ## Applying multiple middleware
//!
//! [`tower::ServiceBuilder`] can be used to combine multiple middleware:
@ -935,14 +898,21 @@
//! use axum::{
//! body::Body,
//! handler::get,
//! http::Request,
//! Router,
//! http::{Request, StatusCode},
//! error_handling::HandleErrorLayer,
//! response::IntoResponse,
//! Router, BoxError,
//! };
//! use tower::ServiceBuilder;
//! use tower_http::compression::CompressionLayer;
//! use std::{borrow::Cow, time::Duration};
//!
//! let middleware_stack = ServiceBuilder::new()
//! // Handle errors from middleware
//! //
//! // This middleware most be added above any fallible
//! // ones if you're using `ServiceBuilder`, due to how ordering works
//! .layer(HandleErrorLayer::new(handle_error))
//! // Return an error after 30 seconds
//! .timeout(Duration::from_secs(30))
//! // Shed load if we're receiving too many requests
@ -950,17 +920,25 @@
//! // Process at most 100 requests concurrently
//! .concurrency_limit(100)
//! // Compress response bodies
//! .layer(CompressionLayer::new())
//! .into_inner();
//! .layer(CompressionLayer::new());
//!
//! let app = Router::new()
//! .route("/", get(|_: Request<Body>| async { /* ... */ }))
//! .layer(middleware_stack);
//!
//! fn handle_error(error: BoxError) -> impl IntoResponse {
//! (
//! StatusCode::INTERNAL_SERVER_ERROR,
//! format!("Something went wrong: {}", error),
//! )
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! See [Error handling](#error-handling) for more details on general error handling in axum.
//!
//! ## Commonly used middleware
//!
//! [`tower::util`] and [`tower_http`] have a large collection of middleware that are compatible
@ -971,6 +949,7 @@
//! body::{Body, BoxBody},
//! handler::get,
//! http::{Request, Response},
//! error_handling::HandleErrorLayer,
//! Router,
//! };
//! use tower::{
@ -980,15 +959,23 @@
//! };
//! use std::convert::Infallible;
//! use tower_http::trace::TraceLayer;
//! #
//! # fn handle_error<T>(error: T) -> axum::http::StatusCode {
//! # axum::http::StatusCode::INTERNAL_SERVER_ERROR
//! # }
//!
//! let middleware_stack = ServiceBuilder::new()
//! // Handle errors from middleware
//! //
//! // This middleware most be added above any fallible
//! // ones if you're using `ServiceBuilder`, due to how ordering works
//! .layer(HandleErrorLayer::new(handle_error))
//! // `TraceLayer` adds high level tracing and logging
//! .layer(TraceLayer::new_for_http())
//! // `AsyncFilterLayer` lets you asynchronously transform the request
//! .layer(AsyncFilterLayer::new(map_request))
//! // `AndThenLayer` lets you asynchronously transform the response
//! .layer(AndThenLayer::new(map_response))
//! .into_inner();
//! .layer(AndThenLayer::new(map_response));
//!
//! async fn map_request(req: Request<Body>) -> Result<Request<Body>, Infallible> {
//! Ok(req)
@ -1010,6 +997,8 @@
//! a middleware. Among other things, this can be useful for doing authorization. See
//! [`extract::extractor_middleware()`] for more details.
//!
//! See [Error handling](#error-handling) for more details on general error handling in axum.
//!
//! ## Writing your own middleware
//!
//! You can also write you own middleware by implementing [`tower::Service`]:
@ -1166,7 +1155,7 @@
//! [`Service`]: tower::Service
//! [`Service::poll_ready`]: tower::Service::poll_ready
//! [`tower::Service`]: tower::Service
//! [`handle_error`]: routing::Router::handle_error
//! [`handle_error`]: error_handling::HandleErrorExt::handle_error
//! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides
//! [`Uuid`]: https://docs.rs/uuid/latest/uuid/
//! [`FromRequest`]: crate::extract::FromRequest
@ -1176,6 +1165,7 @@
//! [axum-debug]: https://docs.rs/axum-debug
//! [`debug_handler`]: https://docs.rs/axum-debug/latest/axum_debug/attr.debug_handler.html
//! [`Handler`]: crate::handler::Handler
//! [`Infallible`]: std::convert::Infallible
#![warn(
clippy::all,
@ -1227,6 +1217,7 @@ mod json;
mod util;
pub mod body;
pub mod error_handling;
pub mod extract;
pub mod handler;
pub mod response;

View file

@ -4,7 +4,6 @@ use crate::{
body::BoxBody,
clone_box_service::CloneBoxService,
routing::{FromEmptyRouter, UriStack},
BoxError,
};
use http::{Request, Response};
use pin_project_lite::pin_project;
@ -28,33 +27,24 @@ opaque_future! {
pin_project! {
/// The response future for [`BoxRoute`](super::BoxRoute).
pub struct BoxRouteFuture<B, E>
where
E: Into<BoxError>,
{
pub struct BoxRouteFuture<B> {
#[pin]
pub(super) inner: Oneshot<
CloneBoxService<Request<B>, Response<BoxBody>, E>,
CloneBoxService<Request<B>, Response<BoxBody>, Infallible>,
Request<B>,
>,
}
}
impl<B, E> Future for BoxRouteFuture<B, E>
where
E: Into<BoxError>,
{
type Output = Result<Response<BoxBody>, E>;
impl<B> Future for BoxRouteFuture<B> {
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
impl<B, E> fmt::Debug for BoxRouteFuture<B, E>
where
E: Into<BoxError>,
{
impl<B> fmt::Debug for BoxRouteFuture<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxRouteFuture").finish()
}
@ -112,11 +102,11 @@ pin_project! {
impl<S, F, B> Future for RouteFuture<S, F, B>
where
S: Service<Request<B>, Response = Response<BoxBody>>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>,
B: Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, S::Error>;
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().state.project() {
@ -141,11 +131,11 @@ pin_project! {
impl<S, F, B> Future for NestedFuture<S, F, B>
where
S: Service<Request<B>, Response = Response<BoxBody>>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error>,
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>,
F: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible>,
B: Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, S::Error>;
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut res: Response<_> = futures_util::ready!(self.project().inner.poll(cx)?);

View file

@ -8,7 +8,6 @@ use crate::{
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
OriginalUri,
},
service::HandleError,
util::{ByteStr, PercentDecodedByteStr},
BoxError,
};
@ -64,7 +63,7 @@ enum MaybeSharedNode {
Shared(Arc<Node<RouteId>>),
}
impl<E> Router<EmptyRouter<E>> {
impl Router<EmptyRouter> {
/// Create a new `Router`.
///
/// Unless you add additional routes this will respond to `404 Not Found` to
@ -77,7 +76,7 @@ impl<E> Router<EmptyRouter<E>> {
}
}
impl<E> Default for Router<EmptyRouter<E>> {
impl Default for Router<EmptyRouter> {
fn default() -> Self {
Self::new()
}
@ -176,7 +175,10 @@ impl<S> Router<S> {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn route<T>(mut self, path: &str, svc: T) -> Router<Route<T, S>> {
pub fn route<T, B>(mut self, path: &str, svc: T) -> Router<Route<T, S>>
where
T: Service<Request<B>, Error = Infallible>,
{
let id = RouteId::next();
if let Err(err) = self.update_node(|node| node.insert(path, id)) {
@ -262,11 +264,20 @@ impl<S> Router<S> {
/// use axum::{
/// Router,
/// service::get,
/// error_handling::HandleErrorExt,
/// http::StatusCode,
/// };
/// use std::{io, convert::Infallible};
/// use tower_http::services::ServeDir;
///
/// // Serves files inside the `public` directory at `GET /public/*`
/// let serve_dir_service = ServeDir::new("public");
/// let serve_dir_service = ServeDir::new("public")
/// .handle_error(|error: io::Error| {
/// (
/// StatusCode::INTERNAL_SERVER_ERROR,
/// format!("Unhandled internal error: {}", error),
/// )
/// });
///
/// let app = Router::new().nest("/public", get(serve_dir_service));
/// # async {
@ -305,7 +316,10 @@ impl<S> Router<S> {
/// for more details.
///
/// [`OriginalUri`]: crate::extract::OriginalUri
pub fn nest<T>(mut self, path: &str, svc: T) -> Router<Nested<T, S>> {
pub fn nest<T, B>(mut self, path: &str, svc: T) -> Router<Nested<T, S>>
where
T: Service<Request<B>, Error = Infallible>,
{
let id = RouteId::next();
if path.contains('*') {
@ -361,10 +375,13 @@ impl<S> Router<S> {
///
/// It also helps with compile times when you have a very large number of
/// routes.
pub fn boxed<ReqBody, ResBody>(self) -> Router<BoxRoute<ReqBody, S::Error>>
pub fn boxed<ReqBody, ResBody>(self) -> Router<BoxRoute<ReqBody>>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
S::Error: Into<BoxError> + Send,
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
+ Clone
+ Send
+ Sync
+ 'static,
S::Future: Send,
ReqBody: Send + 'static,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
@ -374,8 +391,7 @@ impl<S> Router<S> {
ServiceBuilder::new()
.layer_fn(BoxRoute)
.layer_fn(CloneBoxService::new)
.layer(MapResponseBodyLayer::new(box_body))
.into_inner(),
.layer(MapResponseBodyLayer::new(box_body)),
)
}
@ -601,100 +617,19 @@ impl<S> Router<S> {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn or<S2>(self, other: S2) -> Router<Or<S, S2>> {
pub fn or<T, B>(self, other: T) -> Router<Or<S, T>>
where
T: Service<Request<B>, Error = Infallible>,
{
self.map(|first| Or {
first,
second: other,
})
}
/// Handle errors services in this router might produce, by mapping them to
/// responses.
///
/// Unhandled errors will close the connection without sending a response.
///
/// # Example
///
/// ```
/// use axum::{
/// handler::get,
/// http::StatusCode,
/// Router,
/// };
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::{time::Duration, convert::Infallible};
///
/// // This router can never fail, since handlers can never fail.
/// let app = Router::new().route("/", get(|| async {}));
///
/// // Now the router can fail since the `tower::timeout::Timeout`
/// // middleware will return an error if the timeout elapses.
/// let app = app.layer(TimeoutLayer::new(Duration::from_secs(10)));
///
/// // With `handle_error` we can handle errors `Timeout` might produce.
/// // Our router now cannot fail, that is its error type is `Infallible`.
/// let app = app.handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok::<_, Infallible>((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long to handle".to_string(),
/// ))
/// } else {
/// Ok::<_, Infallible>((
/// StatusCode::INTERNAL_SERVER_ERROR,
/// format!("Unhandled error: {}", error),
/// ))
/// }
/// });
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// You can return `Err(_)` from the closure if you don't wish to handle
/// some errors:
///
/// ```
/// use axum::{
/// handler::get,
/// http::StatusCode,
/// Router,
/// };
/// use tower::{BoxError, timeout::TimeoutLayer};
/// use std::time::Duration;
///
/// let app = Router::new()
/// .route("/", get(|| async {}))
/// .layer(TimeoutLayer::new(Duration::from_secs(10)))
/// .handle_error(|error: BoxError| {
/// if error.is::<tower::timeout::error::Elapsed>() {
/// Ok((
/// StatusCode::REQUEST_TIMEOUT,
/// "request took too long to handle".to_string(),
/// ))
/// } else {
/// // return the error as is
/// Err(error)
/// }
/// });
/// # async {
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn handle_error<ReqBody, F>(self, f: F) -> Router<HandleError<S, F, ReqBody>> {
self.map(|svc| HandleError::new(svc, f))
}
/// Check that your service cannot fail.
///
/// That is, its error type is [`Infallible`].
pub fn check_infallible(self) -> Router<CheckInfallible<S>> {
self.map(CheckInfallible)
}
fn map<F, S2>(self, f: F) -> Router<S2>
fn map<F, T>(self, f: F) -> Router<T>
where
F: FnOnce(S) -> S2,
F: FnOnce(S) -> T,
{
Router {
routes: f(self.routes),
@ -740,11 +675,11 @@ impl<S> Router<S> {
impl<ReqBody, ResBody, S> Service<Request<ReqBody>> for Router<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible> + Clone,
ReqBody: Send + Sync + 'static,
{
type Response = Response<ResBody>;
type Error = S::Error;
type Error = Infallible;
type Future = S::Future;
#[inline]
@ -949,12 +884,12 @@ pub struct Route<S, T> {
impl<B, S, T> Service<Request<B>> for Route<S, T>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Error = Infallible;
type Future = RouteFuture<S, T, B>;
#[inline]
@ -988,12 +923,12 @@ pub struct Nested<S, T> {
impl<B, S, T> Service<Request<B>> for Nested<S, T>
where
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
S: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
T: Service<Request<B>, Response = Response<BoxBody>, Error = Infallible> + Clone,
B: Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = S::Error;
type Error = Infallible;
type Future = NestedFuture<S, T, B>;
#[inline]
@ -1020,29 +955,26 @@ where
/// A boxed route trait object.
///
/// See [`Router::boxed`] for more details.
pub struct BoxRoute<B = crate::body::Body, E = Infallible>(
CloneBoxService<Request<B>, Response<BoxBody>, E>,
pub struct BoxRoute<B = crate::body::Body>(
CloneBoxService<Request<B>, Response<BoxBody>, Infallible>,
);
impl<B, E> Clone for BoxRoute<B, E> {
impl<B> Clone for BoxRoute<B> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<B, E> fmt::Debug for BoxRoute<B, E> {
impl<B> fmt::Debug for BoxRoute<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BoxRoute").finish()
}
}
impl<B, E> Service<Request<B>> for BoxRoute<B, E>
where
E: Into<BoxError>,
{
impl<B> Service<Request<B>> for BoxRoute<B> {
type Response = Response<BoxBody>;
type Error = E;
type Future = BoxRouteFuture<B, E>;
type Error = Infallible;
type Future = BoxRouteFuture<B>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -1086,31 +1018,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri {
Uri::from_parts(parts).unwrap()
}
/// Middleware that statically verifies that a service cannot fail.
///
/// Created with [`check_infallible`](Router::check_infallible).
#[derive(Debug, Clone, Copy)]
pub struct CheckInfallible<S>(S);
impl<R, S> Service<R> for CheckInfallible<S>
where
S: Service<R, Error = Infallible>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: R) -> Self::Future {
self.0.call(req)
}
}
/// A [`MakeService`] that produces axum router services.
///
/// [`MakeService`]: tower::make::MakeService
@ -1161,15 +1068,12 @@ mod tests {
assert_send::<EmptyRouter<NotSendSync>>();
assert_sync::<EmptyRouter<NotSendSync>>();
assert_send::<BoxRoute<(), ()>>();
assert_sync::<BoxRoute<(), ()>>();
assert_send::<BoxRoute<()>>();
assert_sync::<BoxRoute<()>>();
assert_send::<Nested<(), ()>>();
assert_sync::<Nested<(), ()>>();
assert_send::<CheckInfallible<()>>();
assert_sync::<CheckInfallible<()>>();
assert_send::<IntoMakeService<()>>();
assert_sync::<IntoMakeService<()>>();
}

View file

@ -6,6 +6,7 @@ use futures_util::ready;
use http::{Request, Response};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
@ -33,8 +34,8 @@ fn traits() {
impl<A, B, ReqBody> Service<Request<ReqBody>> for Or<A, B>
where
A: Service<Request<ReqBody>, Response = Response<BoxBody>> + Clone,
B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = A::Error> + Clone,
A: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = Infallible> + Clone,
B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = Infallible> + Clone,
ReqBody: Send + Sync + 'static,
A: Send + 'static,
B: Send + 'static,
@ -42,7 +43,7 @@ where
B::Future: Send + 'static,
{
type Response = Response<BoxBody>;
type Error = A::Error;
type Error = Infallible;
type Future = ResponseFuture<A, B, ReqBody>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -89,11 +90,11 @@ pin_project! {
impl<A, B, ReqBody> Future for ResponseFuture<A, B, ReqBody>
where
A: Service<Request<ReqBody>, Response = Response<BoxBody>>,
B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = A::Error>,
A: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = Infallible>,
B: Service<Request<ReqBody>, Response = Response<BoxBody>, Error = Infallible>,
ReqBody: Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, A::Error>;
type Output = Result<Response<BoxBody>, Infallible>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {

View file

@ -2,7 +2,6 @@
use crate::{
body::{box_body, BoxBody},
response::IntoResponse,
util::{Either, EitherProj},
BoxError,
};
@ -19,42 +18,6 @@ use std::{
use tower::util::Oneshot;
use tower_service::Service;
pin_project! {
/// Response future for [`HandleError`](super::HandleError).
#[derive(Debug)]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
pub(super) inner: Fut,
pub(super) f: Option<F>,
}
}
impl<Fut, F, E, E2, B, Res> Future for HandleErrorFuture<Fut, F>
where
Fut: Future<Output = Result<Response<B>, E>>,
F: FnOnce(E) -> Result<Res, E2>,
Res: IntoResponse,
B: http_body::Body<Data = Bytes> + Send + Sync + 'static,
B::Error: Into<BoxError> + Send + Sync + 'static,
{
type Output = Result<Response<BoxBody>, E2>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match ready!(this.inner.poll(cx)) {
Ok(res) => Ok(res.map(box_body)).into(),
Err(err) => {
let f = this.f.take().unwrap();
match f(err) {
Ok(res) => Ok(res.into_response().map(box_body)).into(),
Err(err) => Err(err).into(),
}
}
}
}
}
pin_project! {
/// The response future for [`OnMethod`](super::OnMethod).
pub struct OnMethodFuture<S, F, B>

View file

@ -99,17 +99,15 @@
use crate::BoxError;
use crate::{
body::BoxBody,
response::IntoResponse,
routing::{EmptyRouter, MethodFilter},
};
use bytes::Bytes;
use http::{Request, Response};
use std::{
fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower::{util::Oneshot, ServiceExt as _};
use tower::ServiceExt as _;
use tower_service::Service;
pub mod future;
@ -481,18 +479,6 @@ impl<S, F, B> OnMethod<S, F, B> {
_request_body: PhantomData,
}
}
/// Handle errors this service might produce, by mapping them to responses.
///
/// Unhandled errors will close the connection without sending a response.
///
/// Works similarly to [`Router::handle_error`]. See that for more
/// details.
///
/// [`Router::handle_error`]: crate::routing::Router::handle_error
pub fn handle_error<ReqBody, H>(self, f: H) -> HandleError<Self, H, ReqBody> {
HandleError::new(self, f)
}
}
// this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean
@ -532,81 +518,10 @@ where
}
}
/// A [`Service`] adapter that handles errors with a closure.
///
/// Created with
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
/// [`routing::Router::handle_error`](crate::routing::Router::handle_error).
/// See those methods for more details.
pub struct HandleError<S, F, B> {
inner: S,
f: F,
_marker: PhantomData<fn() -> B>,
}
impl<S, F, B> Clone for HandleError<S, F, B>
where
S: Clone,
F: Clone,
{
fn clone(&self) -> Self {
Self::new(self.inner.clone(), self.f.clone())
}
}
impl<S, F, B> HandleError<S, F, B> {
pub(crate) fn new(inner: S, f: F) -> Self {
Self {
inner,
f,
_marker: PhantomData,
}
}
}
impl<S, F, B> fmt::Debug for HandleError<S, F, B>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HandleError")
.field("inner", &self.inner)
.field("f", &format_args!("{}", std::any::type_name::<F>()))
.finish()
}
}
impl<S, F, ReqBody, ResBody, Res, E> Service<Request<ReqBody>> for HandleError<S, F, ReqBody>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E> + Clone,
Res: IntoResponse,
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
{
type Response = Response<BoxBody>;
type Error = E;
type Future = future::HandleErrorFuture<Oneshot<S, Request<ReqBody>>, F>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
future::HandleErrorFuture {
f: Some(self.f.clone()),
inner: self.inner.clone().oneshot(req),
}
}
}
#[test]
fn traits() {
use crate::tests::*;
assert_send::<OnMethod<(), (), NotSendSync>>();
assert_sync::<OnMethod<(), (), NotSendSync>>();
assert_send::<HandleError<(), (), NotSendSync>>();
assert_sync::<HandleError<(), (), NotSendSync>>();
}

View file

@ -1,6 +1,6 @@
use super::*;
use std::future::{pending, ready};
use tower::{timeout::TimeoutLayer, MakeService};
use tower::{timeout::TimeoutLayer, ServiceBuilder};
async fn unit() {}
@ -29,23 +29,17 @@ impl<R> Service<R> for Svc {
}
}
fn check_make_svc<M, R, T, E>(_make_svc: M)
where
M: MakeService<(), R, Response = T, Error = E>,
{
}
fn handle_error<E>(_: E) -> Result<StatusCode, Infallible> {
Ok(StatusCode::INTERNAL_SERVER_ERROR)
}
#[tokio::test]
async fn handler() {
let app = Router::new().route(
"/",
get(forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))),
get(forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
)),
);
let client = TestClient::new(app);
@ -58,9 +52,13 @@ async fn handler() {
async fn handler_multiple_methods_first() {
let app = Router::new().route(
"/",
get(forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)))
get(forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
))
.post(unit),
);
@ -76,9 +74,13 @@ async fn handler_multiple_methods_middle() {
"/",
delete(unit)
.get(
forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
),
)
.post(unit),
);
@ -94,9 +96,13 @@ async fn handler_multiple_methods_last() {
let app = Router::new().route(
"/",
delete(unit).get(
forever
.layer(timeout())
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)),
forever.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| {
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
),
),
);
@ -105,82 +111,3 @@ async fn handler_multiple_methods_last() {
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
}
#[test]
fn service_propagates_errors() {
let app = Router::new().route("/echo", service::post::<_, Body>(Svc));
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
}
#[test]
fn service_nested_propagates_errors() {
let app = Router::new().route(
"/echo",
Router::new().nest("/foo", service::post::<_, Body>(Svc)),
);
check_make_svc::<_, _, _, hyper::Error>(app.into_make_service());
}
#[test]
fn service_handle_on_method() {
let app = Router::new().route(
"/echo",
service::get::<_, Body>(Svc).handle_error(handle_error::<hyper::Error>),
);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_method_multiple() {
let app = Router::new().route(
"/echo",
service::get::<_, Body>(Svc)
.post(Svc)
.handle_error(handle_error::<hyper::Error>),
);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_router() {
let app = Router::new()
.route("/echo", service::get::<_, Body>(Svc))
.handle_error(handle_error::<hyper::Error>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn service_handle_on_router_still_impls_routing_dsl() {
let app = Router::new()
.route("/echo", service::get::<_, Body>(Svc))
.handle_error(handle_error::<hyper::Error>)
.route("/", get(unit));
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[test]
fn layered() {
let app = Router::new()
.route("/echo", get::<_, Body, _>(unit))
.layer(timeout())
.handle_error(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}
#[tokio::test] // async because of `.boxed()`
async fn layered_boxed() {
let app = Router::new()
.route("/echo", get::<_, Body, _>(unit))
.layer(timeout())
.boxed()
.handle_error(handle_error::<BoxError>);
check_make_svc::<_, _, _, Infallible>(app.into_make_service());
}

View file

@ -1,5 +1,6 @@
#![allow(clippy::blacklisted_name)]
use crate::error_handling::HandleErrorLayer;
use crate::BoxError;
use crate::{
extract::{self, Path},
@ -339,7 +340,7 @@ async fn middleware_on_single_route() {
#[tokio::test]
async fn service_in_bottom() {
async fn handler(_req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
async fn handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::new(hyper::Body::empty()))
}
@ -532,7 +533,9 @@ async fn middleware_applies_to_routes_above() {
let app = Router::new()
.route("/one", get(std::future::pending::<()>))
.layer(TimeoutLayer::new(Duration::new(0, 0)))
.handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))
.layer(HandleErrorLayer::new(|_: BoxError| {
StatusCode::REQUEST_TIMEOUT
}))
.route("/two", get(|| async {}));
let client = TestClient::new(app);

View file

@ -1,5 +1,6 @@
use super::*;
use crate::body::box_body;
use crate::error_handling::HandleErrorExt;
use crate::routing::EmptyRouter;
use std::collections::HashMap;
@ -169,10 +170,10 @@ async fn nest_static_file_server() {
let app = Router::new().nest(
"/static",
service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| {
Ok::<_, Infallible>((
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
)
}),
);
@ -255,5 +256,5 @@ async fn multiple_top_level_nests() {
#[tokio::test]
#[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")]
async fn nest_cannot_contain_wildcards() {
Router::<EmptyRouter>::new().nest("/one/*rest", Router::<EmptyRouter>::new());
Router::<EmptyRouter>::new().nest::<_, Body>("/one/*rest", Router::<EmptyRouter>::new());
}

View file

@ -1,5 +1,5 @@
use super::*;
use crate::{extract::OriginalUri, response::IntoResponse, Json};
use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json};
use serde_json::{json, Value};
use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
@ -136,7 +136,7 @@ async fn layer_and_handle_error() {
let two = Router::new()
.route("/timeout", get(futures::future::pending::<()>))
.layer(TimeoutLayer::new(Duration::from_millis(10)))
.handle_error(|_| Ok(StatusCode::REQUEST_TIMEOUT));
.layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT));
let app = one.or(two);
let client = TestClient::new(app);