diff --git a/src/dialogue_state.rs b/src/dialogue_state.rs index fa74e3ae..bd626bd9 100644 --- a/src/dialogue_state.rs +++ b/src/dialogue_state.rs @@ -1,9 +1,9 @@ use proc_macro2::{Ident, Span, TokenStream}; -use quote::{quote, ToTokens}; +use quote::{format_ident, quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, spanned::Spanned, - Fields, GenericParam, ItemEnum, Path, Type, + Fields, FieldsNamed, GenericParam, ItemEnum, Path, Type, }; pub fn expand(item: ItemEnum) -> Result { @@ -64,12 +64,13 @@ pub fn expand(item: ItemEnum) -> Result { }; branches.extend(match &variant.fields { - Fields::Named(_) => { - return Err(syn::Error::new( - variant.span(), - "Named fields does not allowed", - )) - } + Fields::Named(fields) => create_branch_multiple_fields_named( + &enum_ident, + &self_params, + &variant.ident, + &handler.func, + fields, + ), Fields::Unnamed(fields) => match fields.unnamed.len() { 1 => create_branch_one_field( &enum_ident, @@ -77,12 +78,13 @@ pub fn expand(item: ItemEnum) -> Result { &variant.ident, &handler.func, ), - len => { - return Err(syn::Error::new( - fields.span(), - format!("Expected 1 field, found {}", len), - )); - } + len => create_branch_multiple_fields( + &enum_ident, + &self_params, + &variant.ident, + &handler.func, + len, + ), }, Fields::Unit => create_branch_no_fields( &enum_ident, @@ -142,6 +144,59 @@ fn create_branch_one_field( } } +fn create_branch_multiple_fields( + state: &Ident, + state_generics: impl ToTokens, + kind: &Ident, + handler: &Path, + fields_count: usize, +) -> TokenStream { + let fields = gen_variant_field_names(fields_count); + + quote! { + .branch( + dptree::filter_map(|state: #state #state_generics| async move { + match state { #state::#kind(#fields) => Some((#fields)), _ => None } + }).endpoint(#handler) + ) + } +} + +fn gen_variant_field_names(len: usize) -> TokenStream { + let mut fields = quote! {}; + + for i in 0..len { + let idx = format_ident!("_{}", i); + fields.extend(quote! { #idx, }); + } + + return fields; +} + +fn create_branch_multiple_fields_named( + state: &Ident, + state_generics: impl ToTokens, + kind: &Ident, + handler: &Path, + fields_named: &FieldsNamed, +) -> TokenStream { + let mut fields = quote! {}; + + for field in fields_named.named.iter() { + let ident = + field.ident.as_ref().expect("Named fields must have identifiers"); + fields.extend(quote! { #ident, }); + } + + quote! { + .branch( + dptree::filter_map(|state: #state #state_generics| async move { + match state { #state::#kind { #fields } => Some((#fields)), _ => None } + }).endpoint(#handler) + ) + } +} + fn parse_out_type( span: Span, attrs: &[syn::Attribute],