diff --git a/src/dispatching/dialogue/mod.rs b/src/dispatching/dialogue/mod.rs index 9efe7416..de466e8a 100644 --- a/src/dispatching/dialogue/mod.rs +++ b/src/dispatching/dialogue/mod.rs @@ -190,11 +190,11 @@ where /// (This next handler can be an endpoint or a more complex one.) The payload /// format depend on the form of `MyVariant`: /// -/// - For `State::MyVariant(param)`, the payload is `param`. -/// - For `State::MyVariant(param1, ..., paramN)`, the payload is `(param1, -/// ..., paramN)` (where `N` > 1). -/// - For `State::MyVariant { param1, ..., paramN }`, the payload is `(param1, -/// ..., paramN)`. +/// - For `State::MyVariant(param)` and `State::MyVariant { param }`, the +/// payload is `param`. +/// - For `State::MyVariant(param1, ..., paramN)` and `State::MyVariant { +/// param1, ..., paramN }`, the payload is `(param1, ..., paramN)` (where `N` +/// > 1). /// /// ## Dependency requirements /// @@ -216,6 +216,12 @@ macro_rules! handler { _ => None, }) }; + ($($variant:ident)::+ {$param:ident}) => { + $crate::dptree::filter_map(|state| match state { + $($variant)::+{$param} => Some($param), + _ => None, + }) + }; ($($variant:ident)::+ {$($param:ident),+ $(,)?}) => { $crate::dptree::filter_map(|state| match state { $($variant)::+ { $($param),+ } => Some(($($param),+ ,)), @@ -259,6 +265,18 @@ mod tests { assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_))); } + #[tokio::test] + async fn handler_single_fn_variant_trailing_comma() { + let input = State::B(42); + let h = handler![State::B(x,)].endpoint(|(x,): (i32,)| async move { + assert_eq!(x, 42); + 123 + }); + + assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123)); + assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_))); + } + #[tokio::test] async fn handler_fn_variant() { let input = State::C(42, "abc"); @@ -275,7 +293,20 @@ mod tests { #[tokio::test] async fn handler_single_struct_variant() { let input = State::D { foo: 42 }; - let h = handler![State::D { foo }].endpoint(|(x,): (i32,)| async move { + let h = handler![State::D { foo }].endpoint(|x: i32| async move { + assert_eq!(x, 42); + 123 + }); + + assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123)); + assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_))); + } + + #[tokio::test] + async fn handler_single_struct_variant_trailing_comma() { + let input = State::D { foo: 42 }; + #[rustfmt::skip] // rustfmt removes the trailing comma from `State::D { foo, }`, but it plays a vital role in this test. + let h = handler![State::D { foo, }].endpoint(|(x,): (i32,)| async move { assert_eq!(x, 42); 123 });