diff --git a/src/lib.rs b/src/lib.rs index d6ce22ec..30583ba8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,8 @@ use crate::{ use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{ - parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn, ReturnType, + parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn, + ReturnType, Type, }; use std::fmt::Write; @@ -31,10 +32,10 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream { let input = parse_macro_input!(item as ItemFn); let params = input.sig.inputs.iter().collect::>(); - if params.len() != 2 { + if params.len() != 2 && params.len() != 3 { panic!( - "An transition function must accept two parameters: a \ - state type and TransitionIn" + "An transition function must accept two/three parameters: \ + a state type, TransitionIn, and an optional data." ); } @@ -52,17 +53,34 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream { type>" ), }; + let aux_param_type = match params.get(2) { + Some(data_param_type) => match *data_param_type { + FnArg::Typed(typed) => typed.ty.clone(), + _ => unreachable!(), + }, + None => { + let unit_type = proc_macro::TokenStream::from(quote! {()}); + Box::new(parse_macro_input!(unit_type as Type)) + } + }; + let call_fn = match params.get(2) { + Some(_) => { + quote! { #fn_name(self, cx, aux) } + } + None => quote! { #fn_name(self, cx) }, + }; + let item = proc_macro2::TokenStream::from(item_cloned); let impl_transition = quote! { - impl teloxide::dispatching::dialogue::SubTransition< - <#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output> - for #state_type { - fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn) + impl teloxide::dispatching::dialogue::SubTransition for #state_type { + type Aux = #aux_param_type; + type Dialogue = <#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output; + + fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: #aux_param_type) -> futures::future::BoxFuture<'static, #fn_return_type> { #item - - futures::future::FutureExt::boxed(#fn_name(self, cx)) + futures::future::FutureExt::boxed(#call_fn) } } }; @@ -82,11 +100,28 @@ pub fn derive_transition(item: TokenStream) -> TokenStream { write!( dispatch_fn, - "impl teloxide::dispatching::dialogue::Transition for {} {{ fn \ - react(self, cx: teloxide::dispatching::dialogue::TransitionIn) -> \ + "impl teloxide::dispatching::dialogue::Transition<<{0} as \ + teloxide::dispatching::dialogue::SubTransition>::Aux> for {1} {{ fn \ + react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: \ + <{0} as teloxide::dispatching::dialogue::SubTransition>::Aux) -> \ futures::future::BoxFuture<'static, \ teloxide::dispatching::dialogue::TransitionOut> {{ \ futures::future::FutureExt::boxed(async {{ match self {{", + // .unwrap() because empty enumerations are not yet allowed in stable + // Rust. + match &input.variants.iter().next().unwrap().fields { + Fields::Unnamed(fields) => { + fields + .unnamed + .iter() + .next() + .unwrap() + .ty + .to_token_stream() + .to_string() + } + _ => panic!("Only one unnamed field per variant is allowed"), + }, input.ident ) .unwrap(); @@ -95,8 +130,8 @@ pub fn derive_transition(item: TokenStream) -> TokenStream { write!( dispatch_fn, "{}::{}(state) => \ - teloxide::dispatching::dialogue::SubTransition::react(state, \ - cx).await,", + teloxide::dispatching::dialogue::SubTransition::react(state, cx, \ + aux).await,", input.ident, variant.ident ) .unwrap();