Respect an auxiliary parameter

This commit is contained in:
Temirkhan Myrzamadi 2020-07-26 22:53:37 +06:00
parent 2e0c6a57d2
commit 117690514d

View file

@ -18,7 +18,8 @@ use crate::{
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{ use syn::{
parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn, ReturnType, parse_macro_input, DeriveInput, Fields, FnArg, ItemEnum, ItemFn,
ReturnType, Type,
}; };
use std::fmt::Write; use std::fmt::Write;
@ -31,10 +32,10 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn); let input = parse_macro_input!(item as ItemFn);
let params = input.sig.inputs.iter().collect::<Vec<&FnArg>>(); let params = input.sig.inputs.iter().collect::<Vec<&FnArg>>();
if params.len() != 2 { if params.len() != 2 && params.len() != 3 {
panic!( panic!(
"An transition function must accept two parameters: a \ "An transition function must accept two/three parameters: \
state type and TransitionIn" a state type, TransitionIn, and an optional data."
); );
} }
@ -52,17 +53,34 @@ pub fn teloxide(attr: TokenStream, item: TokenStream) -> TokenStream {
type>" type>"
), ),
}; };
let aux_param_type = match params.get(2) {
Some(data_param_type) => match *data_param_type {
FnArg::Typed(typed) => typed.ty.clone(),
_ => unreachable!(),
},
None => {
let unit_type = proc_macro::TokenStream::from(quote! {()});
Box::new(parse_macro_input!(unit_type as Type))
}
};
let call_fn = match params.get(2) {
Some(_) => {
quote! { #fn_name(self, cx, aux) }
}
None => quote! { #fn_name(self, cx) },
};
let item = proc_macro2::TokenStream::from(item_cloned); let item = proc_macro2::TokenStream::from(item_cloned);
let impl_transition = quote! { let impl_transition = quote! {
impl teloxide::dispatching::dialogue::SubTransition< impl teloxide::dispatching::dialogue::SubTransition for #state_type {
<#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output> type Aux = #aux_param_type;
for #state_type { type Dialogue = <#fn_return_type as teloxide::dispatching::dialogue::SubTransitionOutputType>::Output;
fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn)
fn react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: #aux_param_type)
-> futures::future::BoxFuture<'static, #fn_return_type> { -> futures::future::BoxFuture<'static, #fn_return_type> {
#item #item
futures::future::FutureExt::boxed(#call_fn)
futures::future::FutureExt::boxed(#fn_name(self, cx))
} }
} }
}; };
@ -82,11 +100,28 @@ pub fn derive_transition(item: TokenStream) -> TokenStream {
write!( write!(
dispatch_fn, dispatch_fn,
"impl teloxide::dispatching::dialogue::Transition for {} {{ fn \ "impl teloxide::dispatching::dialogue::Transition<<{0} as \
react(self, cx: teloxide::dispatching::dialogue::TransitionIn) -> \ teloxide::dispatching::dialogue::SubTransition>::Aux> for {1} {{ fn \
react(self, cx: teloxide::dispatching::dialogue::TransitionIn, aux: \
<{0} as teloxide::dispatching::dialogue::SubTransition>::Aux) -> \
futures::future::BoxFuture<'static, \ futures::future::BoxFuture<'static, \
teloxide::dispatching::dialogue::TransitionOut<Self>> {{ \ teloxide::dispatching::dialogue::TransitionOut<Self>> {{ \
futures::future::FutureExt::boxed(async {{ match self {{", futures::future::FutureExt::boxed(async {{ match self {{",
// .unwrap() because empty enumerations are not yet allowed in stable
// Rust.
match &input.variants.iter().next().unwrap().fields {
Fields::Unnamed(fields) => {
fields
.unnamed
.iter()
.next()
.unwrap()
.ty
.to_token_stream()
.to_string()
}
_ => panic!("Only one unnamed field per variant is allowed"),
},
input.ident input.ident
) )
.unwrap(); .unwrap();
@ -95,8 +130,8 @@ pub fn derive_transition(item: TokenStream) -> TokenStream {
write!( write!(
dispatch_fn, dispatch_fn,
"{}::{}(state) => \ "{}::{}(state) => \
teloxide::dispatching::dialogue::SubTransition::react(state, \ teloxide::dispatching::dialogue::SubTransition::react(state, cx, \
cx).await,", aux).await,",
input.ident, variant.ident input.ident, variant.ident
) )
.unwrap(); .unwrap();