diff --git a/CHANGELOG.md b/CHANGELOG.md index 803e35ac..d236d9a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Adding a conflicting route will now cause a panic instead of silently making a route unreachable. - Route matching is faster as number of routes increase. - - The routes `/foo` and `/:key` are considered to overlap and will cause a - panic when constructing the router. This might be fixed in the future. + - **breaking:** The routes `/foo` and `/:key` are considered to overlap and + will cause a panic when constructing the router. This might be fixed in the future. - Improve performance of `BoxRoute` ([#339]) - Expand accepted content types for JSON requests ([#378]) - **breaking:** Automatically do percent decoding in `extract::Path` @@ -28,6 +28,73 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 thus wasn't necessary - **breaking:** Change `Connected::connect_info` to return `Self` and remove the associated type `ConnectInfo` ([#396]) +- **breaking:** Simplify error handling model ([#402]): + - All services part of the router are now required to be infallible. + - Error handling utilities have been moved to an `error_handling` module. + - `Router::check_infallible` has been removed since routers are always + infallible with the error handling changes. + - Error handling closures must now handle all errors and thus always return + something that implements `IntoResponse`. + + With these changes handling errors from fallible middleware is done like so: + + ```rust,no_run + use axum::{ + handler::get, + http::StatusCode, + error_handling::HandleErrorLayer, + response::IntoResponse, + Router, BoxError, + }; + use tower::ServiceBuilder; + use std::time::Duration; + + let middleware_stack = ServiceBuilder::new() + // Handle errors from middleware + // + // This middleware most be added above any fallible + // ones if you're using `ServiceBuilder`, due to how ordering works + .layer(HandleErrorLayer::new(handle_error)) + // Return an error after 30 seconds + .timeout(Duration::from_secs(30)); + + let app = Router::new() + .route("/", get(|| async { /* ... */ })) + .layer(middleware_stack); + + fn handle_error(_error: BoxError) -> impl IntoResponse { + StatusCode::REQUEST_TIMEOUT + } + ``` + + And handling errors from fallible leaf services is done like so: + + ```rust + use axum::{ + Router, service, + body::Body, + handler::get, + response::IntoResponse, + http::{Request, Response}, + error_handling::HandleErrorExt, // for `.handle_error` + }; + use std::{io, convert::Infallible}; + use tower::service_fn; + + let app = Router::new() + .route( + "/", + service::get(service_fn(|_req: Request| async { + let contents = tokio::fs::read_to_string("some_file").await?; + Ok::<_, io::Error>(Response::new(Body::from(contents))) + })) + .handle_error(handle_io_error), + ); + + fn handle_io_error(error: io::Error) -> impl IntoResponse { + // ... + } + ``` [#339]: https://github.com/tokio-rs/axum/pull/339 [#286]: https://github.com/tokio-rs/axum/pull/286 @@ -35,6 +102,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#378]: https://github.com/tokio-rs/axum/pull/378 [#363]: https://github.com/tokio-rs/axum/pull/363 [#396]: https://github.com/tokio-rs/axum/pull/396 +[#402]: https://github.com/tokio-rs/axum/pull/402 # 0.2.8 (07. October, 2021) diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index b8153181..c768b441 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -8,6 +8,7 @@ use axum::{ body::Bytes, + error_handling::HandleErrorLayer, extract::{ContentLengthLimit, Extension, Path}, handler::{delete, get, Handler}, http::StatusCode, @@ -18,7 +19,6 @@ use axum::{ use std::{ borrow::Cow, collections::HashMap, - convert::Infallible, net::SocketAddr, sync::{Arc, RwLock}, time::Duration, @@ -52,16 +52,15 @@ async fn main() { // Add middleware to all routes .layer( ServiceBuilder::new() + // Handle errors from middleware + .layer(HandleErrorLayer::new(handle_error)) .load_shed() .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .layer(AddExtensionLayer::new(SharedState::default())) .into_inner(), - ) - // Handle errors from middleware - .handle_error(handle_error) - .check_infallible(); + ); // Run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -126,20 +125,20 @@ fn admin_routes() -> Router { .boxed() } -fn handle_error(error: BoxError) -> Result { +fn handle_error(error: BoxError) -> impl IntoResponse { if error.is::() { - return Ok((StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"))); + return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")); } if error.is::() { - return Ok(( + return ( StatusCode::SERVICE_UNAVAILABLE, Cow::from("service is overloaded, try again later"), - )); + ); } - Ok(( + ( StatusCode::INTERNAL_SERVER_ERROR, Cow::from(format!("Unhandled internal error: {}", error)), - )) + ) } diff --git a/examples/print-request-response/src/main.rs b/examples/print-request-response/src/main.rs index 15fc7f6a..023e3c69 100644 --- a/examples/print-request-response/src/main.rs +++ b/examples/print-request-response/src/main.rs @@ -6,12 +6,13 @@ use axum::{ body::{Body, BoxBody, Bytes}, + error_handling::HandleErrorLayer, handler::post, - http::{Request, Response}, + http::{Request, Response, StatusCode}, Router, }; use std::net::SocketAddr; -use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError}; +use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder}; #[tokio::main] async fn main() { @@ -26,8 +27,17 @@ async fn main() { let app = Router::new() .route("/", post(|| async move { "Hello from `POST /`" })) - .layer(AsyncFilterLayer::new(map_request)) - .layer(AndThenLayer::new(map_response)); + .layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|error| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + ) + })) + .layer(AndThenLayer::new(map_response)) + .layer(AsyncFilterLayer::new(map_request)), + ); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 53c49feb..406ad4c0 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -5,6 +5,7 @@ //! ``` use axum::{ + error_handling::HandleErrorExt, extract::TypedHeader, handler::get, http::StatusCode, @@ -28,10 +29,10 @@ async fn main() { ServeDir::new("examples/sse/assets").append_index_html_on_directories(true), ) .handle_error(|error: std::io::Error| { - Ok::<_, std::convert::Infallible>(( + ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), - )) + ) }); // build our application with a route diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index dbdc2904..ec449788 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -4,8 +4,8 @@ //! cargo run -p example-static-file-server //! ``` -use axum::{http::StatusCode, service, Router}; -use std::{convert::Infallible, net::SocketAddr}; +use axum::{error_handling::HandleErrorExt, http::StatusCode, service, Router}; +use std::net::SocketAddr; use tower_http::{services::ServeDir, trace::TraceLayer}; #[tokio::main] @@ -23,10 +23,10 @@ async fn main() { .nest( "/static", service::get(ServeDir::new(".")).handle_error(|error: std::io::Error| { - Ok::<_, Infallible>(( + ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), - )) + ) }), ) .layer(TraceLayer::new_for_http()); diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index b5a204ef..5e4ac543 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -14,6 +14,7 @@ //! ``` use axum::{ + error_handling::HandleErrorLayer, extract::{Extension, Path, Query}, handler::{get, patch}, http::StatusCode, @@ -23,7 +24,6 @@ use axum::{ use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, - convert::Infallible, net::SocketAddr, sync::{Arc, RwLock}, time::Duration, @@ -49,25 +49,21 @@ async fn main() { // Add middleware to all routes .layer( ServiceBuilder::new() + .layer(HandleErrorLayer::new(|error: BoxError| { + if error.is::() { + Ok(StatusCode::REQUEST_TIMEOUT) + } else { + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + )) + } + })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .layer(AddExtensionLayer::new(db)) .into_inner(), - ) - .handle_error(|error: BoxError| { - let result = if error.is::() { - Ok(StatusCode::REQUEST_TIMEOUT) - } else { - Err(( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - )) - }; - - Ok::<_, Infallible>(result) - }) - // Make sure all errors have been handled - .check_infallible(); + ); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 2170cd75..fbde0803 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -7,6 +7,7 @@ //! ``` use axum::{ + error_handling::HandleErrorExt, extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, TypedHeader, @@ -38,10 +39,10 @@ async fn main() { ServeDir::new("examples/websockets/assets").append_index_html_on_directories(true), ) .handle_error(|error: std::io::Error| { - Ok::<_, std::convert::Infallible>(( + ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), - )) + ) }), ) // routes are matched from bottom to top, so we have to put `nest` at the diff --git a/src/error_handling/mod.rs b/src/error_handling/mod.rs new file mode 100644 index 00000000..a925bdb6 --- /dev/null +++ b/src/error_handling/mod.rs @@ -0,0 +1,199 @@ +//! Error handling utilities +//! +//! See [error handling](../index.html#error-handling) for more details on how +//! error handling works in axum. + +use crate::{ + body::{box_body, BoxBody}, + response::IntoResponse, + BoxError, +}; +use bytes::Bytes; +use futures_util::ready; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::convert::Infallible; +use std::{ + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{util::Oneshot, ServiceExt as _}; +use tower_layer::Layer; +use tower_service::Service; + +/// [`Layer`] that applies [`HandleErrorLayer`] which is a [`Service`] adapter +/// that handles errors by converting them into responses. +/// +/// See [error handling](../index.html#error-handling) for more details. +pub struct HandleErrorLayer { + f: F, + _marker: PhantomData B>, +} + +impl HandleErrorLayer { + /// Create a new `HandleErrorLayer`. + pub fn new(f: F) -> Self { + Self { + f, + _marker: PhantomData, + } + } +} + +impl Layer for HandleErrorLayer +where + F: Clone, +{ + type Service = HandleError; + + fn layer(&self, inner: S) -> Self::Service { + HandleError { + inner, + f: self.f.clone(), + _marker: PhantomData, + } + } +} + +impl Clone for HandleErrorLayer +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + _marker: PhantomData, + } + } +} + +impl fmt::Debug for HandleErrorLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleErrorLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +/// A [`Service`] adapter that handles errors by converting them into responses. +/// +/// See [error handling](../index.html#error-handling) for more details. +pub struct HandleError { + inner: S, + f: F, + _marker: PhantomData B>, +} + +impl Clone for HandleError +where + S: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + Self::new(self.inner.clone(), self.f.clone()) + } +} + +impl HandleError { + /// Create a new `HandleError`. + pub fn new(inner: S, f: F) -> Self { + Self { + inner, + f, + _marker: PhantomData, + } + } +} + +impl fmt::Debug for HandleError +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleError") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for HandleError +where + S: Service, Response = Response> + Clone, + F: FnOnce(S::Error) -> Res + Clone, + Res: IntoResponse, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = HandleErrorFuture>, F>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + HandleErrorFuture { + f: Some(self.f.clone()), + inner: self.inner.clone().oneshot(req), + } + } +} + +/// Handle errors this service might produce, by mapping them to responses. +/// +/// See [error handling](../index.html#error-handling) for more details. +pub trait HandleErrorExt: Service> + Sized { + /// Apply a [`HandleError`] middleware. + fn handle_error(self, f: F) -> HandleError { + HandleError::new(self, f) + } +} + +impl HandleErrorExt for S where S: Service> {} + +pin_project! { + /// Response future for [`HandleError`](super::HandleError). + #[derive(Debug)] + pub struct HandleErrorFuture { + #[pin] + pub(super) inner: Fut, + pub(super) f: Option, + } +} + +impl Future for HandleErrorFuture +where + Fut: Future, E>>, + F: FnOnce(E) -> Res, + Res: IntoResponse, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, +{ + type Output = Result, Infallible>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match ready!(this.inner.poll(cx)) { + Ok(res) => Ok(res.map(box_body)).into(), + Err(err) => { + let f = this.f.take().unwrap(); + let res = f(err); + Ok(res.into_response().map(box_body)).into() + } + } + } +} + +#[test] +fn traits() { + use crate::tests::*; + + assert_send::>(); + assert_sync::>(); +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 2fd0f4c1..c0d1598b 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -5,7 +5,6 @@ use crate::{ extract::{FromRequest, RequestParts}, response::IntoResponse, routing::{EmptyRouter, MethodFilter}, - service::HandleError, util::Either, BoxError, }; @@ -264,9 +263,6 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - /// - /// When adding middleware that might fail its recommended to handle those - /// errors. See [`Layered::handle_error`] for more details. fn layer(self, layer: L) -> Layered where L: Layer>, @@ -426,28 +422,6 @@ impl Layered { _input: PhantomData, } } - - /// Create a new [`Layered`] handler where errors will be handled using the - /// given closure. - /// - /// This is used to convert errors to responses rather than simply - /// terminating the connection. - /// - /// It works similarly to [`routing::Router::handle_error`]. See that for more details. - /// - /// [`routing::Router::handle_error`]: crate::routing::Router::handle_error - pub fn handle_error( - self, - f: F, - ) -> Layered, T> - where - S: Service, Response = Response>, - F: FnOnce(S::Error) -> Result, - Res: IntoResponse, - { - let svc = HandleError::new(self.svc, f); - Layered::new(svc) - } } /// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and diff --git a/src/lib.rs b/src/lib.rs index f65f9f00..d83aa0bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,6 @@ //! - [Handlers](#handlers) //! - [Debugging handler type errors](#debugging-handler-type-errors) //! - [Routing](#routing) -//! - [Matching multiple methods](#matching-multiple-methods) //! - [Routing to any `Service`](#routing-to-any-service) //! - [Routing to fallible services](#routing-to-fallible-services) //! - [Wildcard routes](#wildcard-routes) @@ -18,10 +17,10 @@ //! - [Optional extractors](#optional-extractors) //! - [Customizing extractor responses](#customizing-extractor-responses) //! - [Building responses](#building-responses) +//! - [Error handling](#error-handling) //! - [Applying middleware](#applying-middleware) //! - [To individual handlers](#to-individual-handlers) //! - [To groups of routes](#to-groups-of-routes) -//! - [Error handling](#error-handling) //! - [Applying multiple middleware](#applying-multiple-middleware) //! - [Commonly used middleware](#commonly-used-middleware) //! - [Writing your own middleware](#writing-your-own-middleware) @@ -186,14 +185,15 @@ //! //! ```rust,no_run //! use axum::{ -//! body::Body, -//! http::Request, //! Router, -//! service +//! service, +//! body::Body, +//! error_handling::HandleErrorExt, +//! http::{Request, StatusCode}, //! }; //! use tower_http::services::ServeFile; //! use http::Response; -//! use std::convert::Infallible; +//! use std::{convert::Infallible, io}; //! use tower::service_fn; //! //! let app = Router::new() @@ -205,7 +205,7 @@ //! // to have the response body mapped //! service::any(service_fn(|_: Request| async { //! let res = Response::new(Body::from("Hi from `GET /`")); -//! Ok(res) +//! Ok::<_, Infallible>(res) //! })) //! ) //! .route( @@ -216,13 +216,20 @@ //! let body = Body::from(format!("Hi from `{} /foo`", req.method())); //! let body = axum::body::box_body(body); //! let res = Response::new(body); -//! Ok(res) +//! Ok::<_, Infallible>(res) //! }) //! ) //! .route( //! // GET `/static/Cargo.toml` goes to a service from tower-http //! "/static/Cargo.toml", //! service::get(ServeFile::new("Cargo.toml")) +//! // though we must handle any potential errors +//! .handle_error(|error: io::Error| { +//! ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! format!("Unhandled internal error: {}", error), +//! ) +//! }) //! ); //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); @@ -270,12 +277,12 @@ //! //! ``` //! use axum::{ -//! Router, -//! service, -//! handler::get, -//! http::{Request, Response}, -//! response::IntoResponse, +//! Router, service, //! body::Body, +//! handler::get, +//! response::IntoResponse, +//! http::{Request, Response}, +//! error_handling::HandleErrorExt, //! }; //! use std::{io, convert::Infallible}; //! use tower::service_fn; @@ -293,8 +300,7 @@ //! .handle_error(handle_io_error), //! ); //! -//! fn handle_io_error(error: io::Error) -> Result { -//! # Ok(()) +//! fn handle_io_error(error: io::Error) -> impl IntoResponse { //! // ... //! } //! # async { @@ -664,7 +670,7 @@ //! 1. Use `Result` as your extractor like shown in ["Optional //! extractors"](#optional-extractors). This works well if you're only using //! the extractor in a single handler. -//! 2. Create your own extractor that in its [`FromRequest`] implementing calls +//! 2. Create your own extractor that in its [`FromRequest`] implemention calls //! one of axum's built in extractors but returns a different response for //! rejections. See the [customize-extractor-error] example for more details. //! @@ -781,6 +787,26 @@ //! # }; //! ``` //! +//! # Error handling +//! +//! In the context of axum an "error" specifically means if a [`Service`]'s +//! response future resolves to `Err(Service::Error)`. That means async handler +//! functions can _never_ fail since they always produce a response and their +//! `Service::Error` type is [`Infallible`]. Returning statuses like 404 or 500 +//! are _not_ errors. +//! +//! axum works this way because hyper will close the connection, without sending +//! a response, if an error is encountered. This is not desireable so axum makes +//! it impossible to forget to handle errors. +//! +//! Sometimes you need to route to fallible services or apply fallible +//! middleware in which case you need to handle the errors. That can be done +//! using things from [`error_handling`]. +//! +//! You can find examples here: +//! - [Routing to fallible services](#routing-to-fallible-services) +//! - [Applying fallible middleware](#applying-multiple-middleware) +//! //! # Applying middleware //! //! axum is designed to take full advantage of the tower and tower-http @@ -864,69 +890,6 @@ //! # }; //! ``` //! -//! ## Error handling -//! -//! Handlers created from async functions must always produce a response, even -//! when returning a `Result` the error type must implement -//! [`IntoResponse`]. In practice this makes error handling very predictable and -//! easier to reason about. -//! -//! However when applying middleware, or embedding other tower services, errors -//! might happen. For example [`Timeout`] will return an error if the timeout -//! elapses. By default these errors will be propagated all the way up to hyper -//! where the connection will be closed. If that isn't desirable you can call -//! [`handle_error`](handler::Layered::handle_error) to handle errors from -//! adding a middleware to a handler: -//! -//! ```rust,no_run -//! use axum::{ -//! handler::{get, Handler}, -//! Router, -//! }; -//! use tower::{ -//! BoxError, timeout::{TimeoutLayer, error::Elapsed}, -//! }; -//! use std::{borrow::Cow, time::Duration, convert::Infallible}; -//! use http::StatusCode; -//! -//! let app = Router::new() -//! .route( -//! "/", -//! get(handle -//! .layer(TimeoutLayer::new(Duration::from_secs(30))) -//! // `Timeout` uses `BoxError` as the error type -//! .handle_error(|error: BoxError| { -//! // Check if the actual error type is `Elapsed` which -//! // `Timeout` returns -//! if error.is::() { -//! return Ok::<_, Infallible>(( -//! StatusCode::REQUEST_TIMEOUT, -//! "Request took too long".into(), -//! )); -//! } -//! -//! // If we encounter some error we don't handle return a generic -//! // error -//! return Ok::<_, Infallible>(( -//! StatusCode::INTERNAL_SERVER_ERROR, -//! // `Cow` lets us return either `&str` or `String` -//! Cow::from(format!("Unhandled internal error: {}", error)), -//! )); -//! })), -//! ); -//! -//! async fn handle() {} -//! # async { -//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! The closure passed to [`handle_error`](handler::Layered::handle_error) must -//! return `Result` where `T` implements -//! [`IntoResponse`](response::IntoResponse). -//! -//! See [`routing::Router::handle_error`] for more details. -//! //! ## Applying multiple middleware //! //! [`tower::ServiceBuilder`] can be used to combine multiple middleware: @@ -935,14 +898,21 @@ //! use axum::{ //! body::Body, //! handler::get, -//! http::Request, -//! Router, +//! http::{Request, StatusCode}, +//! error_handling::HandleErrorLayer, +//! response::IntoResponse, +//! Router, BoxError, //! }; //! use tower::ServiceBuilder; //! use tower_http::compression::CompressionLayer; //! use std::{borrow::Cow, time::Duration}; //! //! let middleware_stack = ServiceBuilder::new() +//! // Handle errors from middleware +//! // +//! // This middleware most be added above any fallible +//! // ones if you're using `ServiceBuilder`, due to how ordering works +//! .layer(HandleErrorLayer::new(handle_error)) //! // Return an error after 30 seconds //! .timeout(Duration::from_secs(30)) //! // Shed load if we're receiving too many requests @@ -950,17 +920,25 @@ //! // Process at most 100 requests concurrently //! .concurrency_limit(100) //! // Compress response bodies -//! .layer(CompressionLayer::new()) -//! .into_inner(); +//! .layer(CompressionLayer::new()); //! //! let app = Router::new() //! .route("/", get(|_: Request| async { /* ... */ })) //! .layer(middleware_stack); +//! +//! fn handle_error(error: BoxError) -> impl IntoResponse { +//! ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! format!("Something went wrong: {}", error), +//! ) +//! } //! # async { //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` //! +//! See [Error handling](#error-handling) for more details on general error handling in axum. +//! //! ## Commonly used middleware //! //! [`tower::util`] and [`tower_http`] have a large collection of middleware that are compatible @@ -971,6 +949,7 @@ //! body::{Body, BoxBody}, //! handler::get, //! http::{Request, Response}, +//! error_handling::HandleErrorLayer, //! Router, //! }; //! use tower::{ @@ -980,15 +959,23 @@ //! }; //! use std::convert::Infallible; //! use tower_http::trace::TraceLayer; +//! # +//! # fn handle_error(error: T) -> axum::http::StatusCode { +//! # axum::http::StatusCode::INTERNAL_SERVER_ERROR +//! # } //! //! let middleware_stack = ServiceBuilder::new() +//! // Handle errors from middleware +//! // +//! // This middleware most be added above any fallible +//! // ones if you're using `ServiceBuilder`, due to how ordering works +//! .layer(HandleErrorLayer::new(handle_error)) //! // `TraceLayer` adds high level tracing and logging //! .layer(TraceLayer::new_for_http()) //! // `AsyncFilterLayer` lets you asynchronously transform the request //! .layer(AsyncFilterLayer::new(map_request)) //! // `AndThenLayer` lets you asynchronously transform the response -//! .layer(AndThenLayer::new(map_response)) -//! .into_inner(); +//! .layer(AndThenLayer::new(map_response)); //! //! async fn map_request(req: Request) -> Result, Infallible> { //! Ok(req) @@ -1010,6 +997,8 @@ //! a middleware. Among other things, this can be useful for doing authorization. See //! [`extract::extractor_middleware()`] for more details. //! +//! See [Error handling](#error-handling) for more details on general error handling in axum. +//! //! ## Writing your own middleware //! //! You can also write you own middleware by implementing [`tower::Service`]: @@ -1166,7 +1155,7 @@ //! [`Service`]: tower::Service //! [`Service::poll_ready`]: tower::Service::poll_ready //! [`tower::Service`]: tower::Service -//! [`handle_error`]: routing::Router::handle_error +//! [`handle_error`]: error_handling::HandleErrorExt::handle_error //! [tower-guides]: https://github.com/tower-rs/tower/tree/master/guides //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ //! [`FromRequest`]: crate::extract::FromRequest @@ -1176,6 +1165,7 @@ //! [axum-debug]: https://docs.rs/axum-debug //! [`debug_handler`]: https://docs.rs/axum-debug/latest/axum_debug/attr.debug_handler.html //! [`Handler`]: crate::handler::Handler +//! [`Infallible`]: std::convert::Infallible #![warn( clippy::all, @@ -1227,6 +1217,7 @@ mod json; mod util; pub mod body; +pub mod error_handling; pub mod extract; pub mod handler; pub mod response; diff --git a/src/routing/future.rs b/src/routing/future.rs index 84cbe63d..3be347f3 100644 --- a/src/routing/future.rs +++ b/src/routing/future.rs @@ -4,7 +4,6 @@ use crate::{ body::BoxBody, clone_box_service::CloneBoxService, routing::{FromEmptyRouter, UriStack}, - BoxError, }; use http::{Request, Response}; use pin_project_lite::pin_project; @@ -28,33 +27,24 @@ opaque_future! { pin_project! { /// The response future for [`BoxRoute`](super::BoxRoute). - pub struct BoxRouteFuture - where - E: Into, - { + pub struct BoxRouteFuture { #[pin] pub(super) inner: Oneshot< - CloneBoxService, Response, E>, + CloneBoxService, Response, Infallible>, Request, >, } } -impl Future for BoxRouteFuture -where - E: Into, -{ - type Output = Result, E>; +impl Future for BoxRouteFuture { + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().inner.poll(cx) } } -impl fmt::Debug for BoxRouteFuture -where - E: Into, -{ +impl fmt::Debug for BoxRouteFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRouteFuture").finish() } @@ -112,11 +102,11 @@ pin_project! { impl Future for RouteFuture where - S: Service, Response = Response>, - F: Service, Response = Response, Error = S::Error>, + S: Service, Response = Response, Error = Infallible>, + F: Service, Response = Response, Error = Infallible>, B: Send + Sync + 'static, { - type Output = Result, S::Error>; + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project().state.project() { @@ -141,11 +131,11 @@ pin_project! { impl Future for NestedFuture where - S: Service, Response = Response>, - F: Service, Response = Response, Error = S::Error>, + S: Service, Response = Response, Error = Infallible>, + F: Service, Response = Response, Error = Infallible>, B: Send + Sync + 'static, { - type Output = Result, S::Error>; + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut res: Response<_> = futures_util::ready!(self.project().inner.poll(cx)?); diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 3591d8ef..991f691b 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -8,7 +8,6 @@ use crate::{ connect_info::{Connected, IntoMakeServiceWithConnectInfo}, OriginalUri, }, - service::HandleError, util::{ByteStr, PercentDecodedByteStr}, BoxError, }; @@ -64,7 +63,7 @@ enum MaybeSharedNode { Shared(Arc>), } -impl Router> { +impl Router { /// Create a new `Router`. /// /// Unless you add additional routes this will respond to `404 Not Found` to @@ -77,7 +76,7 @@ impl Router> { } } -impl Default for Router> { +impl Default for Router { fn default() -> Self { Self::new() } @@ -176,7 +175,10 @@ impl Router { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn route(mut self, path: &str, svc: T) -> Router> { + pub fn route(mut self, path: &str, svc: T) -> Router> + where + T: Service, Error = Infallible>, + { let id = RouteId::next(); if let Err(err) = self.update_node(|node| node.insert(path, id)) { @@ -262,11 +264,20 @@ impl Router { /// use axum::{ /// Router, /// service::get, + /// error_handling::HandleErrorExt, + /// http::StatusCode, /// }; + /// use std::{io, convert::Infallible}; /// use tower_http::services::ServeDir; /// /// // Serves files inside the `public` directory at `GET /public/*` - /// let serve_dir_service = ServeDir::new("public"); + /// let serve_dir_service = ServeDir::new("public") + /// .handle_error(|error: io::Error| { + /// ( + /// StatusCode::INTERNAL_SERVER_ERROR, + /// format!("Unhandled internal error: {}", error), + /// ) + /// }); /// /// let app = Router::new().nest("/public", get(serve_dir_service)); /// # async { @@ -305,7 +316,10 @@ impl Router { /// for more details. /// /// [`OriginalUri`]: crate::extract::OriginalUri - pub fn nest(mut self, path: &str, svc: T) -> Router> { + pub fn nest(mut self, path: &str, svc: T) -> Router> + where + T: Service, Error = Infallible>, + { let id = RouteId::next(); if path.contains('*') { @@ -361,10 +375,13 @@ impl Router { /// /// It also helps with compile times when you have a very large number of /// routes. - pub fn boxed(self) -> Router> + pub fn boxed(self) -> Router> where - S: Service, Response = Response> + Clone + Send + Sync + 'static, - S::Error: Into + Send, + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + Sync + + 'static, S::Future: Send, ReqBody: Send + 'static, ResBody: http_body::Body + Send + Sync + 'static, @@ -374,8 +391,7 @@ impl Router { ServiceBuilder::new() .layer_fn(BoxRoute) .layer_fn(CloneBoxService::new) - .layer(MapResponseBodyLayer::new(box_body)) - .into_inner(), + .layer(MapResponseBodyLayer::new(box_body)), ) } @@ -601,100 +617,19 @@ impl Router { /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn or(self, other: S2) -> Router> { + pub fn or(self, other: T) -> Router> + where + T: Service, Error = Infallible>, + { self.map(|first| Or { first, second: other, }) } - /// Handle errors services in this router might produce, by mapping them to - /// responses. - /// - /// Unhandled errors will close the connection without sending a response. - /// - /// # Example - /// - /// ``` - /// use axum::{ - /// handler::get, - /// http::StatusCode, - /// Router, - /// }; - /// use tower::{BoxError, timeout::TimeoutLayer}; - /// use std::{time::Duration, convert::Infallible}; - /// - /// // This router can never fail, since handlers can never fail. - /// let app = Router::new().route("/", get(|| async {})); - /// - /// // Now the router can fail since the `tower::timeout::Timeout` - /// // middleware will return an error if the timeout elapses. - /// let app = app.layer(TimeoutLayer::new(Duration::from_secs(10))); - /// - /// // With `handle_error` we can handle errors `Timeout` might produce. - /// // Our router now cannot fail, that is its error type is `Infallible`. - /// let app = app.handle_error(|error: BoxError| { - /// if error.is::() { - /// Ok::<_, Infallible>(( - /// StatusCode::REQUEST_TIMEOUT, - /// "request took too long to handle".to_string(), - /// )) - /// } else { - /// Ok::<_, Infallible>(( - /// StatusCode::INTERNAL_SERVER_ERROR, - /// format!("Unhandled error: {}", error), - /// )) - /// } - /// }); - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` - /// - /// You can return `Err(_)` from the closure if you don't wish to handle - /// some errors: - /// - /// ``` - /// use axum::{ - /// handler::get, - /// http::StatusCode, - /// Router, - /// }; - /// use tower::{BoxError, timeout::TimeoutLayer}; - /// use std::time::Duration; - /// - /// let app = Router::new() - /// .route("/", get(|| async {})) - /// .layer(TimeoutLayer::new(Duration::from_secs(10))) - /// .handle_error(|error: BoxError| { - /// if error.is::() { - /// Ok(( - /// StatusCode::REQUEST_TIMEOUT, - /// "request took too long to handle".to_string(), - /// )) - /// } else { - /// // return the error as is - /// Err(error) - /// } - /// }); - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` - pub fn handle_error(self, f: F) -> Router> { - self.map(|svc| HandleError::new(svc, f)) - } - - /// Check that your service cannot fail. - /// - /// That is, its error type is [`Infallible`]. - pub fn check_infallible(self) -> Router> { - self.map(CheckInfallible) - } - - fn map(self, f: F) -> Router + fn map(self, f: F) -> Router where - F: FnOnce(S) -> S2, + F: FnOnce(S) -> T, { Router { routes: f(self.routes), @@ -740,11 +675,11 @@ impl Router { impl Service> for Router where - S: Service, Response = Response> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, ReqBody: Send + Sync + 'static, { type Response = Response; - type Error = S::Error; + type Error = Infallible; type Future = S::Future; #[inline] @@ -949,12 +884,12 @@ pub struct Route { impl Service> for Route where - S: Service, Response = Response> + Clone, - T: Service, Response = Response, Error = S::Error> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + T: Service, Response = Response, Error = Infallible> + Clone, B: Send + Sync + 'static, { type Response = Response; - type Error = S::Error; + type Error = Infallible; type Future = RouteFuture; #[inline] @@ -988,12 +923,12 @@ pub struct Nested { impl Service> for Nested where - S: Service, Response = Response> + Clone, - T: Service, Response = Response, Error = S::Error> + Clone, + S: Service, Response = Response, Error = Infallible> + Clone, + T: Service, Response = Response, Error = Infallible> + Clone, B: Send + Sync + 'static, { type Response = Response; - type Error = S::Error; + type Error = Infallible; type Future = NestedFuture; #[inline] @@ -1020,29 +955,26 @@ where /// A boxed route trait object. /// /// See [`Router::boxed`] for more details. -pub struct BoxRoute( - CloneBoxService, Response, E>, +pub struct BoxRoute( + CloneBoxService, Response, Infallible>, ); -impl Clone for BoxRoute { +impl Clone for BoxRoute { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl fmt::Debug for BoxRoute { +impl fmt::Debug for BoxRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRoute").finish() } } -impl Service> for BoxRoute -where - E: Into, -{ +impl Service> for BoxRoute { type Response = Response; - type Error = E; - type Future = BoxRouteFuture; + type Error = Infallible; + type Future = BoxRouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -1086,31 +1018,6 @@ fn with_path(uri: &Uri, new_path: &str) -> Uri { Uri::from_parts(parts).unwrap() } -/// Middleware that statically verifies that a service cannot fail. -/// -/// Created with [`check_infallible`](Router::check_infallible). -#[derive(Debug, Clone, Copy)] -pub struct CheckInfallible(S); - -impl Service for CheckInfallible -where - S: Service, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(cx) - } - - #[inline] - fn call(&mut self, req: R) -> Self::Future { - self.0.call(req) - } -} - /// A [`MakeService`] that produces axum router services. /// /// [`MakeService`]: tower::make::MakeService @@ -1161,15 +1068,12 @@ mod tests { assert_send::>(); assert_sync::>(); - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); assert_send::>(); assert_sync::>(); - assert_send::>(); - assert_sync::>(); - assert_send::>(); assert_sync::>(); } diff --git a/src/routing/or.rs b/src/routing/or.rs index fabf169b..4c985e95 100644 --- a/src/routing/or.rs +++ b/src/routing/or.rs @@ -6,6 +6,7 @@ use futures_util::ready; use http::{Request, Response}; use pin_project_lite::pin_project; use std::{ + convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, @@ -33,8 +34,8 @@ fn traits() { impl Service> for Or where - A: Service, Response = Response> + Clone, - B: Service, Response = Response, Error = A::Error> + Clone, + A: Service, Response = Response, Error = Infallible> + Clone, + B: Service, Response = Response, Error = Infallible> + Clone, ReqBody: Send + Sync + 'static, A: Send + 'static, B: Send + 'static, @@ -42,7 +43,7 @@ where B::Future: Send + 'static, { type Response = Response; - type Error = A::Error; + type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -89,11 +90,11 @@ pin_project! { impl Future for ResponseFuture where - A: Service, Response = Response>, - B: Service, Response = Response, Error = A::Error>, + A: Service, Response = Response, Error = Infallible>, + B: Service, Response = Response, Error = Infallible>, ReqBody: Send + Sync + 'static, { - type Output = Result, A::Error>; + type Output = Result, Infallible>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { diff --git a/src/service/future.rs b/src/service/future.rs index ac734508..ad5cb38a 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -2,7 +2,6 @@ use crate::{ body::{box_body, BoxBody}, - response::IntoResponse, util::{Either, EitherProj}, BoxError, }; @@ -19,42 +18,6 @@ use std::{ use tower::util::Oneshot; use tower_service::Service; -pin_project! { - /// Response future for [`HandleError`](super::HandleError). - #[derive(Debug)] - pub struct HandleErrorFuture { - #[pin] - pub(super) inner: Fut, - pub(super) f: Option, - } -} - -impl Future for HandleErrorFuture -where - Fut: Future, E>>, - F: FnOnce(E) -> Result, - Res: IntoResponse, - B: http_body::Body + Send + Sync + 'static, - B::Error: Into + Send + Sync + 'static, -{ - type Output = Result, E2>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - match ready!(this.inner.poll(cx)) { - Ok(res) => Ok(res.map(box_body)).into(), - Err(err) => { - let f = this.f.take().unwrap(); - match f(err) { - Ok(res) => Ok(res.into_response().map(box_body)).into(), - Err(err) => Err(err).into(), - } - } - } - } -} - pin_project! { /// The response future for [`OnMethod`](super::OnMethod). pub struct OnMethodFuture diff --git a/src/service/mod.rs b/src/service/mod.rs index 1d53837a..05e1006e 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -99,17 +99,15 @@ use crate::BoxError; use crate::{ body::BoxBody, - response::IntoResponse, routing::{EmptyRouter, MethodFilter}, }; use bytes::Bytes; use http::{Request, Response}; use std::{ - fmt, marker::PhantomData, task::{Context, Poll}, }; -use tower::{util::Oneshot, ServiceExt as _}; +use tower::ServiceExt as _; use tower_service::Service; pub mod future; @@ -481,18 +479,6 @@ impl OnMethod { _request_body: PhantomData, } } - - /// Handle errors this service might produce, by mapping them to responses. - /// - /// Unhandled errors will close the connection without sending a response. - /// - /// Works similarly to [`Router::handle_error`]. See that for more - /// details. - /// - /// [`Router::handle_error`]: crate::routing::Router::handle_error - pub fn handle_error(self, f: H) -> HandleError { - HandleError::new(self, f) - } } // this is identical to `routing::OnMethod`'s implementation. Would be nice to find a way to clean @@ -532,81 +518,10 @@ where } } -/// A [`Service`] adapter that handles errors with a closure. -/// -/// Created with -/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or -/// [`routing::Router::handle_error`](crate::routing::Router::handle_error). -/// See those methods for more details. -pub struct HandleError { - inner: S, - f: F, - _marker: PhantomData B>, -} - -impl Clone for HandleError -where - S: Clone, - F: Clone, -{ - fn clone(&self) -> Self { - Self::new(self.inner.clone(), self.f.clone()) - } -} - -impl HandleError { - pub(crate) fn new(inner: S, f: F) -> Self { - Self { - inner, - f, - _marker: PhantomData, - } - } -} - -impl fmt::Debug for HandleError -where - S: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HandleError") - .field("inner", &self.inner) - .field("f", &format_args!("{}", std::any::type_name::())) - .finish() - } -} - -impl Service> for HandleError -where - S: Service, Response = Response> + Clone, - F: FnOnce(S::Error) -> Result + Clone, - Res: IntoResponse, - ResBody: http_body::Body + Send + Sync + 'static, - ResBody::Error: Into + Send + Sync + 'static, -{ - type Response = Response; - type Error = E; - type Future = future::HandleErrorFuture>, F>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - future::HandleErrorFuture { - f: Some(self.f.clone()), - inner: self.inner.clone().oneshot(req), - } - } -} - #[test] fn traits() { use crate::tests::*; assert_send::>(); assert_sync::>(); - - assert_send::>(); - assert_sync::>(); } diff --git a/src/tests/handle_error.rs b/src/tests/handle_error.rs index eec54946..ffb9d92d 100644 --- a/src/tests/handle_error.rs +++ b/src/tests/handle_error.rs @@ -1,6 +1,6 @@ use super::*; use std::future::{pending, ready}; -use tower::{timeout::TimeoutLayer, MakeService}; +use tower::{timeout::TimeoutLayer, ServiceBuilder}; async fn unit() {} @@ -29,23 +29,17 @@ impl Service for Svc { } } -fn check_make_svc(_make_svc: M) -where - M: MakeService<(), R, Response = T, Error = E>, -{ -} - -fn handle_error(_: E) -> Result { - Ok(StatusCode::INTERNAL_SERVER_ERROR) -} - #[tokio::test] async fn handler() { let app = Router::new().route( "/", - get(forever - .layer(timeout()) - .handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))), + get(forever.layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| { + StatusCode::REQUEST_TIMEOUT + })) + .layer(timeout()), + )), ); let client = TestClient::new(app); @@ -58,9 +52,13 @@ async fn handler() { async fn handler_multiple_methods_first() { let app = Router::new().route( "/", - get(forever - .layer(timeout()) - .handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT))) + get(forever.layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| { + StatusCode::REQUEST_TIMEOUT + })) + .layer(timeout()), + )) .post(unit), ); @@ -76,9 +74,13 @@ async fn handler_multiple_methods_middle() { "/", delete(unit) .get( - forever - .layer(timeout()) - .handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)), + forever.layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| { + StatusCode::REQUEST_TIMEOUT + })) + .layer(timeout()), + ), ) .post(unit), ); @@ -94,9 +96,13 @@ async fn handler_multiple_methods_last() { let app = Router::new().route( "/", delete(unit).get( - forever - .layer(timeout()) - .handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)), + forever.layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| { + StatusCode::REQUEST_TIMEOUT + })) + .layer(timeout()), + ), ), ); @@ -105,82 +111,3 @@ async fn handler_multiple_methods_last() { let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } - -#[test] -fn service_propagates_errors() { - let app = Router::new().route("/echo", service::post::<_, Body>(Svc)); - - check_make_svc::<_, _, _, hyper::Error>(app.into_make_service()); -} - -#[test] -fn service_nested_propagates_errors() { - let app = Router::new().route( - "/echo", - Router::new().nest("/foo", service::post::<_, Body>(Svc)), - ); - - check_make_svc::<_, _, _, hyper::Error>(app.into_make_service()); -} - -#[test] -fn service_handle_on_method() { - let app = Router::new().route( - "/echo", - service::get::<_, Body>(Svc).handle_error(handle_error::), - ); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} - -#[test] -fn service_handle_on_method_multiple() { - let app = Router::new().route( - "/echo", - service::get::<_, Body>(Svc) - .post(Svc) - .handle_error(handle_error::), - ); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} - -#[test] -fn service_handle_on_router() { - let app = Router::new() - .route("/echo", service::get::<_, Body>(Svc)) - .handle_error(handle_error::); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} - -#[test] -fn service_handle_on_router_still_impls_routing_dsl() { - let app = Router::new() - .route("/echo", service::get::<_, Body>(Svc)) - .handle_error(handle_error::) - .route("/", get(unit)); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} - -#[test] -fn layered() { - let app = Router::new() - .route("/echo", get::<_, Body, _>(unit)) - .layer(timeout()) - .handle_error(handle_error::); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} - -#[tokio::test] // async because of `.boxed()` -async fn layered_boxed() { - let app = Router::new() - .route("/echo", get::<_, Body, _>(unit)) - .layer(timeout()) - .boxed() - .handle_error(handle_error::); - - check_make_svc::<_, _, _, Infallible>(app.into_make_service()); -} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 78839678..a859d5fa 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,5 +1,6 @@ #![allow(clippy::blacklisted_name)] +use crate::error_handling::HandleErrorLayer; use crate::BoxError; use crate::{ extract::{self, Path}, @@ -339,7 +340,7 @@ async fn middleware_on_single_route() { #[tokio::test] async fn service_in_bottom() { - async fn handler(_req: Request) -> Result, hyper::Error> { + async fn handler(_req: Request) -> Result, Infallible> { Ok(Response::new(hyper::Body::empty())) } @@ -532,7 +533,9 @@ async fn middleware_applies_to_routes_above() { let app = Router::new() .route("/one", get(std::future::pending::<()>)) .layer(TimeoutLayer::new(Duration::new(0, 0))) - .handle_error(|_: BoxError| Ok::<_, Infallible>(StatusCode::REQUEST_TIMEOUT)) + .layer(HandleErrorLayer::new(|_: BoxError| { + StatusCode::REQUEST_TIMEOUT + })) .route("/two", get(|| async {})); let client = TestClient::new(app); diff --git a/src/tests/nest.rs b/src/tests/nest.rs index 228d2237..a91238ae 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,5 +1,6 @@ use super::*; use crate::body::box_body; +use crate::error_handling::HandleErrorExt; use crate::routing::EmptyRouter; use std::collections::HashMap; @@ -169,10 +170,10 @@ async fn nest_static_file_server() { let app = Router::new().nest( "/static", service::get(tower_http::services::ServeDir::new(".")).handle_error(|error| { - Ok::<_, Infallible>(( + ( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", error), - )) + ) }), ); @@ -255,5 +256,5 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::::new().nest("/one/*rest", Router::::new()); + Router::::new().nest::<_, Body>("/one/*rest", Router::::new()); } diff --git a/src/tests/or.rs b/src/tests/or.rs index 71aeadca..e9b78f94 100644 --- a/src/tests/or.rs +++ b/src/tests/or.rs @@ -1,5 +1,5 @@ use super::*; -use crate::{extract::OriginalUri, response::IntoResponse, Json}; +use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json}; use serde_json::{json, Value}; use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; @@ -136,7 +136,7 @@ async fn layer_and_handle_error() { let two = Router::new() .route("/timeout", get(futures::future::pending::<()>)) .layer(TimeoutLayer::new(Duration::from_millis(10))) - .handle_error(|_| Ok(StatusCode::REQUEST_TIMEOUT)); + .layer(HandleErrorLayer::new(|_| StatusCode::REQUEST_TIMEOUT)); let app = one.or(two); let client = TestClient::new(app);