Respect an auxiliary parameter

This commit is contained in:
Temirkhan Myrzamadi 2020-07-26 22:53:37 +06:00
parent 2e0c6a57d2
commit 117690514d

View file

@ -18,7 +18,8 @@ use crate::{
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn, ReturnType,
parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn,
ReturnType, Type,
};
use std::fmt::Write;
@ -31,10 +32,10 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let params = input.sig.inputs.iter().collect::<Vec<&FnArg>>();
if params.len() != 2 {
if params.len() != 2 && params.len() != 3 {
panic!(
"An transition function must accept two parameters: a \
state type and TransitionIn"
"An transition function must accept two/three parameters: \
a state type, TransitionIn, and an optional data."
);
}
@ -52,17 +53,34 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream {
type>"
),
};
let aux_param_type = match params.get(2) {
Some(data_param_type) => match *data_param_type {
FnArg::Typed(typed) => typed.ty.clone(),
_ => unreachable!(),
},
None => {
let unit_type = proc_macro::TokenStream::from(quote! {()});
Box::new(parse_macro_input!(unit_type as Type))
}
};
let call_fn = match params.get(2) {
Some(_) => {
quote! { #fn_name(self, cx, aux) }
}
None => quote! { #fn_name(self, cx) },
};
let item = proc_macro2::TokenStream::from(item_cloned);
let impl_transition = quote! {
impl teloxide::dispatching::dialogue::SubTransition<
<#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output>
for #state_type {
fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn)
impl teloxide::dispatching::dialogue::SubTransition for #state_type {
type Aux = #aux_param_type;
type Dialogue = <#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output;
fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: #aux_param_type)
-> futures::future::BoxFuture<'static, #fn_return_type> {
#item
futures::future::FutureExt::boxed(#fn_name(self, cx))
futures::future::FutureExt::boxed(#call_fn)
}
}
};
@ -82,11 +100,28 @@ pub fn derive_transition(item: TokenStream) -> TokenStream {
write!(
dispatch_fn,
"impl teloxide::dispatching::dialogue::Transition for {} {{ fn \
react(self, cx: teloxide::dispatching::dialogue::TransitionIn) -> \
"impl teloxide::dispatching::dialogue::Transition<<{0} as \
teloxide::dispatching::dialogue::SubTransition>::Aux> for {1} {{ fn \
react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: \
<{0} as teloxide::dispatching::dialogue::SubTransition>::Aux) -> \
futures::future::BoxFuture<'static, \
teloxide::dispatching::dialogue::TransitionOut<Self>> {{ \
futures::future::FutureExt::boxed(async {{ match self {{",
// .unwrap() because empty enumerations are not yet allowed in stable
// Rust.
match &input.variants.iter().next().unwrap().fields {
Fields::Unnamed(fields) => {
fields
.unnamed
.iter()
.next()
.unwrap()
.ty
.to_token_stream()
.to_string()
}
_ => panic!("Only one unnamed field per variant is allowed"),
},
input.ident
)
.unwrap();
@ -95,8 +130,8 @@ pub fn derive_transition(item: TokenStream) -> TokenStream {
write!(
dispatch_fn,
"{}::{}(state) => \
teloxide::dispatching::dialogue::SubTransition::react(state, \
cx).await,",
teloxide::dispatching::dialogue::SubTransition::react(state, cx, \
aux).await,",
input.ident, variant.ident
)
.unwrap();