diff --git a/axum/src/body/mod.rs b/axum/src/body/mod.rs index 5705ac57..6b449cf4 100644 --- a/axum/src/body/mod.rs +++ b/axum/src/body/mod.rs @@ -1,6 +1,6 @@ //! HTTP body utilities. -use crate::{BoxError, Error}; +use crate::{util::try_downcast, BoxError, Error}; mod stream_body; @@ -27,7 +27,7 @@ where B: http_body::Body + Send + 'static, B::Error: Into, { - body.map_err(Error::new).boxed_unsync() + try_downcast(body).unwrap_or_else(|body| body.map_err(Error::new).boxed_unsync()) } pub(crate) fn empty() -> BoxBody { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index f660a5cd..f384999a 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -8,7 +8,7 @@ use crate::{ MatchedPath, OriginalUri, }, routing::strip_prefix::StripPrefix, - util::{ByteStr, PercentDecodedByteStr}, + util::{try_downcast, ByteStr, PercentDecodedByteStr}, BoxError, }; use bytes::Bytes; @@ -657,20 +657,6 @@ impl Fallback { } } -fn try_downcast(k: K) -> Result -where - T: 'static, - K: Send + 'static, -{ - use std::any::Any; - - let k = Box::new(k) as Box; - match k.downcast() { - Ok(t) => Ok(*t), - Err(other) => Err(*other.downcast().unwrap()), - } -} - enum Endpoint { MethodRouter(MethodRouter), Route(Route), diff --git a/axum/src/util.rs b/axum/src/util.rs index 61419c8e..71c50e77 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -65,3 +65,22 @@ pin_project! { B { #[pin] inner: B }, } } + +pub(crate) fn try_downcast(k: K) -> Result +where + T: 'static, + K: Send + 'static, +{ + let mut k = Some(k); + if let Some(k) = ::downcast_mut::>(&mut k) { + Ok(k.take().unwrap()) + } else { + Err(k.unwrap()) + } +} + +#[test] +fn test_try_downcast() { + assert_eq!(try_downcast::(5_u32), Err(5_u32)); + assert_eq!(try_downcast::(5_i32), Ok(5_i32)); +}