From 75b5615ccdf7c63ff4c2ad49729645f83d164d61 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 7 Aug 2021 19:56:44 +0200 Subject: [PATCH] Add `axum::Error` (#150) Replace `BoxStdError` and supports downcasting --- CHANGELOG.md | 2 ++ src/body.rs | 25 +++-------------------- src/error.rs | 28 ++++++++++++++++++++++++++ src/extract/form.rs | 3 ++- src/extract/rejection.rs | 19 ++++++++++-------- src/extract/request_parts.rs | 5 +++-- src/extract/ws.rs | 39 ++++++++++++++++++------------------ src/json.rs | 3 ++- src/lib.rs | 9 ++++++--- src/macros.rs | 6 +++--- src/response.rs | 9 ++++++--- src/sse.rs | 5 +++-- 12 files changed, 88 insertions(+), 65 deletions(-) create mode 100644 src/error.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index b923d755..def255bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `EmptyRouter` - `ExtractorMiddleware` - `ExtractorMiddlewareLayer` +- Replace `axum::body::BoxStdError` with `axum::Error`, which supports downcasting ([#150](https://github.com/tokio-rs/axum/pull/150)) +- `WebSocket` now uses `axum::Error` as its error type ([#150](https://github.com/tokio-rs/axum/pull/150)) # 0.1.3 (06. August, 2021) diff --git a/src/body.rs b/src/body.rs index c18c8823..5ec10bbc 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,8 +1,8 @@ //! HTTP body utilities. +use crate::Error; use bytes::Bytes; use http_body::Body as _; -use std::{error::Error as StdError, fmt}; use tower::BoxError; #[doc(no_inline)] @@ -12,7 +12,7 @@ pub use hyper::body::Body; /// /// This is used in axum as the response body type for applications. Its /// necessary to unify multiple response bodies types into one. -pub type BoxBody = http_body::combinators::BoxBody; +pub type BoxBody = http_body::combinators::BoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub fn box_body(body: B) -> BoxBody @@ -20,28 +20,9 @@ where B: http_body::Body + Send + Sync + 'static, B::Error: Into, { - body.map_err(|err| BoxStdError(err.into())).boxed() + body.map_err(Error::new).boxed() } pub(crate) fn empty() -> BoxBody { box_body(http_body::Empty::new()) } - -/// A boxed error trait object that implements [`std::error::Error`]. -/// -/// This is necessary for compatibility with middleware that changes the error -/// type of the response body. -#[derive(Debug)] -pub struct BoxStdError(pub(crate) tower::BoxError); - -impl StdError for BoxStdError { - fn source(&self) -> std::option::Option<&(dyn StdError + 'static)> { - self.0.source() - } -} - -impl fmt::Display for BoxStdError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..849d5ce8 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,28 @@ +use std::{error::Error as StdError, fmt}; +use tower::BoxError; + +/// Errors that can happen when using axum. +#[derive(Debug)] +pub struct Error { + inner: BoxError, +} + +impl Error { + pub(crate) fn new(error: impl Into) -> Self { + Self { + inner: error.into(), + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + Some(&*self.inner) + } +} diff --git a/src/extract/form.rs b/src/extract/form.rs index 0aed6856..1dc93fa2 100644 --- a/src/extract/form.rs +++ b/src/extract/form.rs @@ -4,6 +4,7 @@ use bytes::Buf; use http::Method; use serde::de::DeserializeOwned; use std::ops::Deref; +use tower::BoxError; /// Extractor that deserializes `application/x-www-form-urlencoded` requests /// into some type. @@ -44,7 +45,7 @@ where T: DeserializeOwned, B: http_body::Body + Send, B::Data: Send, - B::Error: Into, + B::Error: Into, { type Rejection = FormRejection; diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index c2488b4a..adfa3628 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -1,7 +1,10 @@ //! Rejection response types. use super::IntoResponse; -use crate::body::{box_body, BoxBody, BoxStdError}; +use crate::{ + body::{box_body, BoxBody}, + Error, +}; use bytes::Bytes; use http_body::Full; use std::convert::Infallible; @@ -46,7 +49,7 @@ define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to parse the request body as JSON"] /// Rejection type for [`Json`](super::Json). - pub struct InvalidJsonBody(BoxError); + pub struct InvalidJsonBody(Error); } define_rejection! { @@ -62,7 +65,7 @@ define_rejection! { #[body = "Missing request extension"] /// Rejection type for [`Extension`](super::Extension) if an expected /// request extension was not found. - pub struct MissingExtension(BoxError); + pub struct MissingExtension(Error); } define_rejection! { @@ -70,7 +73,7 @@ define_rejection! { #[body = "Failed to buffer the request body"] /// Rejection type for extractors that buffer the request body. Used if the /// request body cannot be buffered due to an error. - pub struct FailedToBufferBody(BoxError); + pub struct FailedToBufferBody(Error); } define_rejection! { @@ -78,7 +81,7 @@ define_rejection! { #[body = "Request body didn't contain valid UTF-8"] /// Rejection type used when buffering the request into a [`String`] if the /// body doesn't contain valid UTF-8. - pub struct InvalidUtf8(BoxError); + pub struct InvalidUtf8(Error); } define_rejection! { @@ -183,7 +186,7 @@ impl IntoResponse for InvalidPathParam { /// couldn't be deserialized into the target type. #[derive(Debug)] pub struct FailedToDeserializeQueryString { - error: BoxError, + error: Error, type_name: &'static str, } @@ -193,7 +196,7 @@ impl FailedToDeserializeQueryString { E: Into, { FailedToDeserializeQueryString { - error: error.into(), + error: Error::new(error), type_name: std::any::type_name::(), } } @@ -330,7 +333,7 @@ where T: IntoResponse, { type Body = BoxBody; - type BodyError = BoxStdError; + type BodyError = Error; fn into_response(self) -> http::Response { match self { diff --git a/src/extract/request_parts.rs b/src/extract/request_parts.rs index 2be45a12..f4afcdd4 100644 --- a/src/extract/request_parts.rs +++ b/src/extract/request_parts.rs @@ -7,6 +7,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use tower::BoxError; #[async_trait] impl FromRequest for Request @@ -176,7 +177,7 @@ impl FromRequest for Bytes where B: http_body::Body + Send, B::Data: Send, - B::Error: Into, + B::Error: Into, { type Rejection = BytesRejection; @@ -196,7 +197,7 @@ impl FromRequest for String where B: http_body::Body + Send, B::Data: Send, - B::Error: Into, + B::Error: Into, { type Rejection = StringRejection; diff --git a/src/extract/ws.rs b/src/extract/ws.rs index 87e22adc..d2c07d7d 100644 --- a/src/extract/ws.rs +++ b/src/extract/ws.rs @@ -37,7 +37,7 @@ use self::rejection::*; use super::{rejection::*, FromRequest, RequestParts}; -use crate::response::IntoResponse; +use crate::{response::IntoResponse, Error}; use async_trait::async_trait; use bytes::Bytes; use futures_util::{ @@ -64,7 +64,6 @@ use tokio_tungstenite::{ }, WebSocketStream, }; -use tower::BoxError; /// Extractor for establishing WebSocket connections. /// @@ -332,32 +331,32 @@ impl WebSocket { /// Receive another message. /// /// Returns `None` if the stream stream has closed. - pub async fn recv(&mut self) -> Option> { + pub async fn recv(&mut self) -> Option> { self.next().await } /// Send a message. - pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> { + pub async fn send(&mut self, msg: Message) -> Result<(), Error> { self.inner .send(msg.into_tungstenite()) .await - .map_err(Into::into) + .map_err(Error::new) } /// Gracefully close this WebSocket. - pub async fn close(mut self) -> Result<(), BoxError> { - self.inner.close(None).await.map_err(Into::into) + pub async fn close(mut self) -> Result<(), Error> { + self.inner.close(None).await.map_err(Error::new) } } impl Stream for WebSocket { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.inner.poll_next_unpin(cx).map(|option_msg| { option_msg.map(|result_msg| { result_msg - .map_err(Into::into) + .map_err(Error::new) .map(Message::from_tungstenite) }) }) @@ -365,24 +364,24 @@ impl Stream for WebSocket { } impl Sink for WebSocket { - type Error = BoxError; + type Error = Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into) + Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new) } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { Pin::new(&mut self.inner) .start_send(item.into_tungstenite()) - .map_err(Into::into) + .map_err(Error::new) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into) + Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into) + Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new) } } @@ -483,12 +482,12 @@ impl Message { } /// Attempt to consume the WebSocket message and convert it to a String. - pub fn into_text(self) -> Result { + pub fn into_text(self) -> Result { match self { Self::Text(string) => Ok(string), - Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => { - Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) - } + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data) + .map_err(|err| err.utf8_error()) + .map_err(Error::new)?), Self::Close(None) => Ok(String::new()), Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), } @@ -496,11 +495,11 @@ impl Message { /// Attempt to get a &str from the WebSocket message, /// this will try to convert binary data to utf8. - pub fn to_text(&self) -> Result<&str, BoxError> { + pub fn to_text(&self) -> Result<&str, Error> { match *self { Self::Text(ref string) => Ok(string), Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { - Ok(std::str::from_utf8(data)?) + Ok(std::str::from_utf8(data).map_err(Error::new)?) } Self::Close(None) => Ok(""), Self::Close(Some(ref frame)) => Ok(&frame.reason), diff --git a/src/json.rs b/src/json.rs index b0df06be..a62158ed 100644 --- a/src/json.rs +++ b/src/json.rs @@ -15,6 +15,7 @@ use std::{ convert::Infallible, ops::{Deref, DerefMut}, }; +use tower::BoxError; /// JSON Extractor/Response /// @@ -89,7 +90,7 @@ where T: DeserializeOwned, B: http_body::Body + Send, B::Data: Send, - B::Error: Into, + B::Error: Into, { type Rejection = JsonRejection; diff --git a/src/lib.rs b/src/lib.rs index 8e0d2369..18968381 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ //! //! ```rust //! use axum::{prelude::*, body::BoxBody}; -//! use tower::{Service, ServiceExt, BoxError}; +//! use tower::{Service, ServiceExt}; //! use http::{Method, Response, StatusCode}; //! use std::convert::Infallible; //! @@ -563,7 +563,7 @@ //! use tower_http::services::ServeFile; //! use http::Response; //! use std::convert::Infallible; -//! use tower::{service_fn, BoxError}; +//! use tower::service_fn; //! //! let app = route( //! // Any request to `/` goes to a service @@ -715,6 +715,7 @@ use tower::Service; pub(crate) mod macros; mod buffer; +mod error; mod json; mod util; @@ -729,14 +730,16 @@ pub mod sse; #[cfg(test)] mod tests; +#[doc(no_inline)] pub use async_trait::async_trait; #[doc(no_inline)] pub use http; #[doc(no_inline)] pub use hyper::Server; +#[doc(no_inline)] pub use tower_http::add_extension::{AddExtension, AddExtensionLayer}; -pub use crate::json::Json; +pub use self::{error::Error, json::Json}; pub mod prelude { //! Re-exports of important traits, types, and functions used with axum. Meant to be glob diff --git a/src/macros.rs b/src/macros.rs index cfb09eeb..daf4d0dd 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -64,18 +64,18 @@ macro_rules! define_rejection { #[status = $status:ident] #[body = $body:expr] $(#[$m:meta])* - pub struct $name:ident (BoxError); + pub struct $name:ident (Error); ) => { $(#[$m])* #[derive(Debug)] - pub struct $name(pub(crate) tower::BoxError); + pub struct $name(pub(crate) crate::Error); impl $name { pub(crate) fn from_err(err: E) -> Self where E: Into, { - Self(err.into()) + Self(crate::Error::new(err)) } } diff --git a/src/response.rs b/src/response.rs index e0c7ad97..6036e7e8 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,6 +1,9 @@ //! Types and traits for generating responses. -use crate::body::{box_body, BoxBody, BoxStdError}; +use crate::{ + body::{box_body, BoxBody}, + Error, +}; use bytes::Bytes; use http::{header, HeaderMap, HeaderValue, Response, StatusCode}; use http_body::{ @@ -139,7 +142,7 @@ where K: IntoResponse, { type Body = BoxBody; - type BodyError = BoxStdError; + type BodyError = Error; fn into_response(self) -> Response { match self { @@ -155,7 +158,7 @@ where E: IntoResponse, { type Body = BoxBody; - type BodyError = BoxStdError; + type BodyError = Error; fn into_response(self) -> Response { match self { diff --git a/src/sse.rs b/src/sse.rs index 1fa13a1e..4a967a84 100644 --- a/src/sse.rs +++ b/src/sse.rs @@ -70,9 +70,10 @@ //! ``` use crate::{ - body::{box_body, BoxBody, BoxStdError}, + body::{box_body, BoxBody}, extract::{FromRequest, RequestParts}, response::IntoResponse, + Error, }; use async_trait::async_trait; use futures_util::{ @@ -276,7 +277,7 @@ where let stream = stream .map_ok(|event| event.to_string()) - .map_err(|err| BoxStdError(err.into())) + .map_err(Error::new) .into_stream(); let body = box_body(Body::wrap_stream(stream));