From 05fa8885b590d14a9b5a3fc7e0f962f6fa293831 Mon Sep 17 00:00:00 2001 From: p0lunin Date: Wed, 29 Dec 2021 14:22:52 +0200 Subject: [PATCH] add dialogue_state.rs --- src/dialogue_state.rs | 172 ++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 13 ++++ 2 files changed, 185 insertions(+) create mode 100644 src/dialogue_state.rs diff --git a/src/dialogue_state.rs b/src/dialogue_state.rs new file mode 100644 index 00000000..fc2ff643 --- /dev/null +++ b/src/dialogue_state.rs @@ -0,0 +1,172 @@ +use proc_macro2::{Ident, Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + spanned::Spanned, + Fields, GenericParam, ItemEnum, Path, Type, +}; + +pub fn expand(item: ItemEnum) -> Result { + let enum_ident = &item.ident; + let self_params_with_bounds = { + let params = &item.generics.params; + if params.len() != 0 { + quote! { < #params > } + } else { + quote! {} + } + }; + let self_params = { + let params = &item.generics.params; + if params.len() != 0 { + let mut params = quote! { < }; + item.generics.params.iter().for_each(|param| match param { + GenericParam::Type(ty) => { + let ident = &ty.ident; + params.extend(quote! { #ident, }); + } + GenericParam::Lifetime(li) => { + let li = &li.lifetime; + params.extend(quote! { #li, }) + } + GenericParam::Const(_par) => todo!(), + }); + params.extend(quote! { > }); + params + } else { + quote! {} + } + }; + let where_clause = match item.generics.where_clause.clone() { + Some(mut clause) => { + let predicate = quote! { Self: Clone + Send + Sync + 'static }; + clause.predicates.push(syn::parse2(predicate).unwrap()); + Some(clause) + } + x => x, + }; + let out = parse_out_type(item.ident.span(), &item.attrs)?; + + let mut branches = quote! {}; + for variant in item.variants.iter() { + let handler = { + let handler_attr = variant + .attrs + .iter() + .find(|attr| attr.path.is_ident("handler")) + .ok_or_else(|| { + syn::Error::new( + variant.span(), + "Expected `handler` attribute.", + ) + })?; + handler_attr.parse_args::()? + }; + + branches.extend(match &variant.fields { + Fields::Named(_) => { + return Err(syn::Error::new( + variant.span(), + "Named fields does not allowed", + )) + } + Fields::Unnamed(fields) => match fields.unnamed.len() { + 1 => create_branch_one_field( + &enum_ident, + &self_params, + &variant.ident, + &handler.func, + ), + len => { + return Err(syn::Error::new( + fields.span(), + format!("Expected 1 field, found {}", len), + )); + } + }, + Fields::Unit => create_branch_no_fields( + &enum_ident, + &self_params, + &variant.ident, + &handler.func, + ), + }); + } + + Ok(quote! {const _: () = { + fn assert_clone() {} + + use teloxide::dptree; + use teloxide::dispatching2::dialogue::Dialogue; + + impl #self_params_with_bounds teloxide::dispatching2::HandlerFactory for #enum_ident #self_params #where_clause { + type Out = #out; + + fn handler() -> dptree::Handler<'static, dptree::di::DependencyMap, Self::Out> { + assert_clone::<#enum_ident #self_params>(); + + dptree::entry() + #branches + } + } + };}) +} + +fn create_branch_no_fields( + state: &Ident, + state_generics: impl ToTokens, + kind: &Ident, + handler: &Path, +) -> TokenStream { + quote! { + .branch( + dptree::filter(|state: #state #state_generics| async move { + match state { #state::#kind => true, _ => false } + }).endpoint(#handler) + ) + } +} + +fn create_branch_one_field( + state: &Ident, + state_generics: impl ToTokens, + kind: &Ident, + handler: &Path, +) -> TokenStream { + quote! { + .branch( + dptree::filter_map(|state: #state #state_generics| async move { + match state { #state::#kind(arg) => Some(arg), _ => None } + }).endpoint(#handler) + ) + } +} + +fn parse_out_type( + span: Span, + attrs: &[syn::Attribute], +) -> Result { + let mut out = None; + for x in attrs { + if x.path.is_ident("out") { + out = Some(x.parse_args::()?); + } + } + if let Some(out) = out { + return Ok(out); + } + Err(syn::Error::new( + span, + "There are must be 2 attributes: `out` and `store`", + )) +} + +pub struct HandlerAttr { + func: Path, +} + +impl Parse for HandlerAttr { + fn parse(input: ParseStream) -> Result { + Ok(Self { func: input.parse::()? }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 5744c436..02576212 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ mod attr; mod command; mod command_enum; +mod dialogue_state; mod fields_parse; mod rename_rules; @@ -24,6 +25,18 @@ use syn::{ use std::fmt::Write; +#[proc_macro_derive(DialogueState, attributes(handler, out, store))] +pub fn derive_dialogue_state(item: TokenStream) -> TokenStream { + let input = parse_macro_input!(item as ItemEnum); + match dialogue_state::expand(input) { + Ok(s) => { + let s = s.into(); + s + } + Err(e) => e.into_compile_error().into(), + } +} + /// The docs is below. /// /// The only accepted form at the current moment is `#[teloxide(subtransition)]`