diff --git a/axum-debug/CHANGELOG.md b/axum-debug/CHANGELOG.md index 702e6b5a..492d62ff 100644 --- a/axum-debug/CHANGELOG.md +++ b/axum-debug/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- Fix `Result` generating invalid code ([#588]) + +[#588]: https://github.com/tokio-rs/axum/pull/588 # 0.3.0 (03. December 2021) diff --git a/axum-debug/src/lib.rs b/axum-debug/src/lib.rs index 36da6bc8..8fe522c4 100644 --- a/axum-debug/src/lib.rs +++ b/axum-debug/src/lib.rs @@ -251,6 +251,24 @@ mod debug_handler { }; let span = ty.span(); + let declare_inputs = item_fn + .sig + .inputs + .iter() + .filter_map(|arg| match arg { + FnArg::Receiver(_) => None, + FnArg::Typed(pat_ty) => { + let pat = &pat_ty.pat; + let ty = &pat_ty.ty; + Some(quote! { + let #pat: #ty = panic!(); + }) + } + }) + .collect::(); + + let block = &item_fn.block; + let make_value_name = format_ident!( "__axum_debug_check_{}_into_response_make_value", item_fn.sig.ident @@ -259,18 +277,18 @@ mod debug_handler { let make = if item_fn.sig.asyncness.is_some() { quote_spanned! {span=> #[allow(warnings)] - async fn #make_value_name() -> #ty { panic!() } - } - } else if let syn::Type::ImplTrait(_) = &**ty { - // lets just assume it returns `impl Future` - quote_spanned! {span=> - #[allow(warnings)] - fn #make_value_name() -> #ty { async { panic!() } } + async fn #make_value_name() -> #ty { + #declare_inputs + #block + } } } else { quote_spanned! {span=> #[allow(warnings)] - fn #make_value_name() -> #ty { panic!() } + fn #make_value_name() -> #ty { + #declare_inputs + #block + } } }; @@ -285,7 +303,7 @@ mod debug_handler { let value = #receiver #make_value_name().await; fn check(_: T) where T: ::axum::response::IntoResponse - {} + {} check(value); } } @@ -296,9 +314,11 @@ mod debug_handler { #make let value = #make_value_name().await; + fn check(_: T) - where T: ::axum::response::IntoResponse + where T: ::axum::response::IntoResponse {} + check(value); } } diff --git a/axum-debug/tests/pass/result_impl_into_response.rs b/axum-debug/tests/pass/result_impl_into_response.rs new file mode 100644 index 00000000..c35773c7 --- /dev/null +++ b/axum-debug/tests/pass/result_impl_into_response.rs @@ -0,0 +1,132 @@ +use axum::{ + async_trait, + extract::{FromRequest, RequestParts}, + response::IntoResponse, +}; +use axum_debug::debug_handler; + +fn main() {} + +#[debug_handler] +fn concrete_future() -> std::future::Ready> { + std::future::ready(Ok(())) +} + +#[debug_handler] +fn impl_future() -> impl std::future::Future> { + std::future::ready(Ok(())) +} + +// === no args === + +#[debug_handler] +async fn handler_no_arg_one() -> Result { + Ok(()) +} + +#[debug_handler] +async fn handler_no_arg_two() -> Result<(), impl IntoResponse> { + Err(()) +} + +#[debug_handler] +async fn handler_no_arg_three() -> Result { + Ok::<_, ()>(()) +} + +#[debug_handler] +async fn handler_no_arg_four() -> Result { + Err::<(), _>(()) +} + +// === args === + +#[debug_handler] +async fn handler_one(foo: String) -> Result { + dbg!(foo); + Ok(()) +} + +#[debug_handler] +async fn handler_two(foo: String) -> Result<(), impl IntoResponse> { + dbg!(foo); + Err(()) +} + +#[debug_handler] +async fn handler_three(foo: String) -> Result { + dbg!(foo); + Ok::<_, ()>(()) +} + +#[debug_handler] +async fn handler_four(foo: String) -> Result { + dbg!(foo); + Err::<(), _>(()) +} + +// === no args with receiver === + +struct A; + +impl A { + #[debug_handler] + async fn handler_no_arg_one(self) -> Result { + Ok(()) + } + + #[debug_handler] + async fn handler_no_arg_two(self) -> Result<(), impl IntoResponse> { + Err(()) + } + + #[debug_handler] + async fn handler_no_arg_three(self) -> Result { + Ok::<_, ()>(()) + } + + #[debug_handler] + async fn handler_no_arg_four(self) -> Result { + Err::<(), _>(()) + } +} + +// === args with receiver === + +impl A { + #[debug_handler] + async fn handler_one(self, foo: String) -> Result { + dbg!(foo); + Ok(()) + } + + #[debug_handler] + async fn handler_two(self, foo: String) -> Result<(), impl IntoResponse> { + dbg!(foo); + Err(()) + } + + #[debug_handler] + async fn handler_three(self, foo: String) -> Result { + dbg!(foo); + Ok::<_, ()>(()) + } + + #[debug_handler] + async fn handler_four(self, foo: String) -> Result { + dbg!(foo); + Err::<(), _>(()) + } +} + +#[async_trait] +impl FromRequest for A +where + B: Send + 'static, +{ + type Rejection = (); + + async fn from_request(_req: &mut RequestParts) -> Result { + unimplemented!() + } +}