[WIP] implement allowed updates in dispatcher

This commit is contained in:
Maybe Waffle 2022-04-13 15:44:18 +04:00
parent 03521bfd3d
commit 4db52436f3
11 changed files with 179 additions and 97 deletions

View file

@ -62,7 +62,8 @@ teloxide-macros = { git = "https://github.com/teloxide/teloxide-macros.git", rev
serde_json = "1.0" serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
dptree = { version = "0.1.0" } #dptree = { version = "0.1.0" }
dptree = { git = "https://github.com/WaffleLapkin/dptree", rev = "192f3fe" }
tokio = { version = "1.8", features = ["fs"] } tokio = { version = "1.8", features = ["fs"] }
tokio-util = "0.6" tokio-util = "0.6"

View file

@ -0,0 +1,58 @@
use std::collections::HashSet;
use dptree::{MaybeSpecial, UpdateSet};
use teloxide_core::types::AllowedUpdate;
pub struct AllowedUpdates {
inner: MaybeSpecial<HashSet<AllowedUpdate>>,
}
impl AllowedUpdates {
pub(crate) fn of(allowed: AllowedUpdate) -> Self {
let mut set = HashSet::with_capacity(1);
set.insert(allowed);
Self { inner: MaybeSpecial::Known(set) }
}
pub(crate) fn get_param(&self) -> Vec<AllowedUpdate> {
use AllowedUpdate::*;
match &self.inner {
MaybeSpecial::Known(set) => set.iter().cloned().collect(),
MaybeSpecial::Invisible => panic!("No updates were allowed"),
MaybeSpecial::Unknown => vec![
Message,
EditedMessage,
ChannelPost,
EditedChannelPost,
InlineQuery,
ChosenInlineResult,
CallbackQuery,
ShippingQuery,
PreCheckoutQuery,
Poll,
PollAnswer,
MyChatMember,
ChatMember,
],
}
}
}
impl UpdateSet for AllowedUpdates {
fn unknown() -> Self {
Self { inner: UpdateSet::unknown() }
}
fn invisible() -> Self {
Self { inner: UpdateSet::invisible() }
}
fn union(&self, other: &Self) -> Self {
Self { inner: self.inner.union(&other.inner) }
}
fn intersection(&self, other: &Self) -> Self {
Self { inner: self.inner.intersection(&other.inner) }
}
}

View file

@ -244,6 +244,8 @@ macro_rules! handler {
mod tests { mod tests {
use std::ops::ControlFlow; use std::ops::ControlFlow;
use crate::dispatching::UpdateHandler;
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State { enum State {
A, A,
@ -257,7 +259,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_empty_variant() { async fn handler_empty_variant() {
let input = State::A; let input = State::A;
let h = handler![State::A].endpoint(|| async move { 123 }); let h: dptree::Handler<_, _> = handler![State::A].endpoint(|| async move { 123 });
assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123)); assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123));
assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_))); assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_)));
@ -266,7 +268,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_single_fn_variant() { async fn handler_single_fn_variant() {
let input = State::B(42); let input = State::B(42);
let h = handler![State::B(x)].endpoint(|x: i32| async move { let h: dptree::Handler<_, _> = handler![State::B(x)].endpoint(|x: i32| async move {
assert_eq!(x, 42); assert_eq!(x, 42);
123 123
}); });
@ -278,7 +280,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_single_fn_variant_trailing_comma() { async fn handler_single_fn_variant_trailing_comma() {
let input = State::B(42); let input = State::B(42);
let h = handler![State::B(x,)].endpoint(|(x,): (i32,)| async move { let h: dptree::Handler<_, _> = handler![State::B(x,)].endpoint(|(x,): (i32,)| async move {
assert_eq!(x, 42); assert_eq!(x, 42);
123 123
}); });
@ -290,11 +292,12 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_fn_variant() { async fn handler_fn_variant() {
let input = State::C(42, "abc"); let input = State::C(42, "abc");
let h = handler![State::C(x, y)].endpoint(|(x, str): (i32, &'static str)| async move { let h: dptree::Handler<_, _> =
assert_eq!(x, 42); handler![State::C(x, y)].endpoint(|(x, str): (i32, &'static str)| async move {
assert_eq!(str, "abc"); assert_eq!(x, 42);
123 assert_eq!(str, "abc");
}); 123
});
assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123)); assert_eq!(h.dispatch(dptree::deps![input]).await, ControlFlow::Break(123));
assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_))); assert!(matches!(h.dispatch(dptree::deps![State::Other]).await, ControlFlow::Continue(_)));
@ -303,7 +306,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_single_struct_variant() { async fn handler_single_struct_variant() {
let input = State::D { foo: 42 }; let input = State::D { foo: 42 };
let h = handler![State::D { foo }].endpoint(|x: i32| async move { let h: dptree::Handler<_, _> = handler![State::D { foo }].endpoint(|x: i32| async move {
assert_eq!(x, 42); assert_eq!(x, 42);
123 123
}); });
@ -316,7 +319,7 @@ mod tests {
async fn handler_single_struct_variant_trailing_comma() { async fn handler_single_struct_variant_trailing_comma() {
let input = State::D { foo: 42 }; let input = State::D { foo: 42 };
#[rustfmt::skip] // rustfmt removes the trailing comma from `State::D { foo, }`, but it plays a vital role in this test. #[rustfmt::skip] // rustfmt removes the trailing comma from `State::D { foo, }`, but it plays a vital role in this test.
let h = handler![State::D { foo, }].endpoint(|(x,): (i32,)| async move { let h: dptree::Handler<_, _> = handler![State::D { foo, }].endpoint(|(x,): (i32,)| async move {
assert_eq!(x, 42); assert_eq!(x, 42);
123 123
}); });
@ -328,7 +331,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn handler_struct_variant() { async fn handler_struct_variant() {
let input = State::E { foo: 42, bar: "abc" }; let input = State::E { foo: 42, bar: "abc" };
let h = let h: dptree::Handler<_, _> =
handler![State::E { foo, bar }].endpoint(|(x, str): (i32, &'static str)| async move { handler![State::E { foo, bar }].endpoint(|(x, str): (i32, &'static str)| async move {
assert_eq!(x, 42); assert_eq!(x, 42);
assert_eq!(str, "abc"); assert_eq!(str, "abc");

View file

@ -1,18 +1,18 @@
use crate::{ use crate::{
dispatching::{ dispatching::{
distribution::default_distribution_function, stop_token::StopToken, update_listeners, distribution::default_distribution_function, stop_token::StopToken, update_listeners,
update_listeners::UpdateListener, DefaultKey, ShutdownToken, update_listeners::UpdateListener, AllowedUpdates, DefaultKey, ShutdownToken,
}, },
error_handlers::{ErrorHandler, LoggingErrorHandler}, error_handlers::{ErrorHandler, LoggingErrorHandler},
requests::{Request, Requester}, requests::{Request, Requester},
types::{AllowedUpdate, Update, UpdateKind}, types::{Update, UpdateKind},
utils::shutdown_token::shutdown_check_timeout_for, utils::shutdown_token::shutdown_check_timeout_for,
}; };
use dptree::di::{DependencyMap, DependencySupplier}; use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::HashMap,
fmt::Debug, fmt::Debug,
hash::Hash, hash::Hash,
ops::{ControlFlow, Deref}, ops::{ControlFlow, Deref},
@ -132,7 +132,6 @@ where
handler, handler,
default_handler, default_handler,
error_handler, error_handler,
allowed_updates: Default::default(),
state: ShutdownToken::new(), state: ShutdownToken::new(),
distribution_f, distribution_f,
worker_queue_size, worker_queue_size,
@ -165,8 +164,6 @@ pub struct Dispatcher<R, Err, Key> {
default_worker: Option<Worker>, default_worker: Option<Worker>,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>, error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
// TODO: respect allowed_udpates
allowed_updates: HashSet<AllowedUpdate>,
state: ShutdownToken, state: ShutdownToken,
} }
@ -180,7 +177,8 @@ struct Worker {
// webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates // webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates
/// A handler that processes updates from Telegram. /// A handler that processes updates from Telegram.
pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(), Err>>; pub type UpdateHandler<Err> =
dptree::Handler<'static, DependencyMap, Result<(), Err>, AllowedUpdates>;
type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>; type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
@ -267,7 +265,9 @@ where
self.dependencies.insert(me); self.dependencies.insert(me);
self.dependencies.insert(self.bot.clone()); self.dependencies.insert(self.bot.clone());
update_listener.hint_allowed_updates(&mut self.allowed_updates.clone().into_iter()); update_listener.hint_allowed_updates(
&mut self.handler.required_update_kinds_set().get_param().into_iter(),
);
let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener); let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener);
let mut stop_token = Some(update_listener.stop_token()); let mut stop_token = Some(update_listener.stop_token());

View file

@ -1,10 +1,14 @@
#![allow(clippy::redundant_closure_call)] #![allow(clippy::redundant_closure_call)]
use dptree::{di::DependencyMap, Handler}; use dptree::{di::DependencyMap, Handler};
use teloxide_core::types::{Message, Update, UpdateKind};
use crate::{
dispatching::AllowedUpdates,
types::{AllowedUpdate, Message, Update, UpdateKind},
};
macro_rules! define_ext { macro_rules! define_ext {
($ext_name:ident, $for_ty:ty => $( ($func:ident, $proj_fn:expr, $fn_doc:expr) ,)*) => { ($ext_name:ident, $for_ty:ty => $( ($func:ident, $proj_fn:expr, $fn_doc:expr $(, $Allowed:ident)? ) ,)*) => {
#[doc = concat!("Filter methods for [`", stringify!($for_ty), "`].")] #[doc = concat!("Filter methods for [`", stringify!($for_ty), "`].")]
pub trait $ext_name<Out>: private::Sealed { pub trait $ext_name<Out>: private::Sealed {
$( define_ext!(@sig $func, $fn_doc); )* $( define_ext!(@sig $func, $fn_doc); )*
@ -14,17 +18,25 @@ macro_rules! define_ext {
where where
Out: Send + Sync + 'static, Out: Send + Sync + 'static,
{ {
$( define_ext!(@impl $for_ty, $func, $proj_fn); )* $( define_ext!(@impl $for_ty, $func, $proj_fn $(, $Allowed )? ); )*
} }
}; };
(@sig $func:ident, $fn_doc:expr) => { (@sig $func:ident, $fn_doc:expr) => {
#[doc = $fn_doc] #[doc = $fn_doc]
fn $func() -> Handler<'static, DependencyMap, Out>; fn $func() -> Handler<'static, DependencyMap, Out, AllowedUpdates>;
};
(@impl $for_ty:ty, $func:ident, $proj_fn:expr, $Allowed:ident) => {
fn $func() -> Handler<'static, DependencyMap, Out, AllowedUpdates> {
dptree::filter_map_with_requirements(AllowedUpdates::of(AllowedUpdate::$Allowed), move |input: $for_ty| {
$proj_fn(input)
})
}
}; };
(@impl $for_ty:ty, $func:ident, $proj_fn:expr) => { (@impl $for_ty:ty, $func:ident, $proj_fn:expr) => {
fn $func() -> Handler<'static, DependencyMap, Out> { fn $func() -> Handler<'static, DependencyMap, Out, AllowedUpdates> {
dptree::filter_map(move |input: $for_ty| { dptree::filter_map(move |input: $for_ty| {
$proj_fn(input) $proj_fn(input)
}) })
@ -75,7 +87,7 @@ define_message_ext! {
} }
macro_rules! define_update_ext { macro_rules! define_update_ext {
($( ($func:ident, $kind:path) ,)*) => { ($( ($func:ident, $kind:path, $Allowed:ident) ,)*) => {
define_ext! { define_ext! {
UpdateFilterExt, crate::types::Update => UpdateFilterExt, crate::types::Update =>
$(( $((
@ -84,7 +96,8 @@ macro_rules! define_update_ext {
$kind(x) => Some(x), $kind(x) => Some(x),
_ => None, _ => None,
}, },
concat!("Filters out [`crate::types::", stringify!($kind), "`] objects.") concat!("Filters out [`crate::types::", stringify!($kind), "`] objects."),
$Allowed
),)* ),)*
} }
} }
@ -92,17 +105,17 @@ macro_rules! define_update_ext {
// May be expanded in the future. // May be expanded in the future.
define_update_ext! { define_update_ext! {
(filter_message, UpdateKind::Message), (filter_message, UpdateKind::Message, Message),
(filter_edited_message, UpdateKind::EditedMessage), (filter_edited_message, UpdateKind::EditedMessage, EditedMessage),
(filter_channel_post, UpdateKind::ChannelPost), (filter_channel_post, UpdateKind::ChannelPost, ChannelPost),
(filter_edited_channel_post, UpdateKind::EditedChannelPost), (filter_edited_channel_post, UpdateKind::EditedChannelPost, EditedChannelPost),
(filter_inline_query, UpdateKind::InlineQuery), (filter_inline_query, UpdateKind::InlineQuery, InlineQuery),
(filter_chosen_inline_result, UpdateKind::ChosenInlineResult), (filter_chosen_inline_result, UpdateKind::ChosenInlineResult, ChosenInlineResult),
(filter_callback_query, UpdateKind::CallbackQuery), (filter_callback_query, UpdateKind::CallbackQuery, CallbackQuery),
(filter_shipping_query, UpdateKind::ShippingQuery), (filter_shipping_query, UpdateKind::ShippingQuery, ShippingQuery),
(filter_pre_checkout_query, UpdateKind::PreCheckoutQuery), (filter_pre_checkout_query, UpdateKind::PreCheckoutQuery, PreCheckoutQuery),
(filter_poll, UpdateKind::Poll), (filter_poll, UpdateKind::Poll, Poll),
(filter_poll_answer, UpdateKind::PollAnswer), (filter_poll_answer, UpdateKind::PollAnswer, PollAnswer),
(filter_my_chat_member, UpdateKind::MyChatMember), (filter_my_chat_member, UpdateKind::MyChatMember, MyChatMember),
(filter_chat_member, UpdateKind::ChatMember), (filter_chat_member, UpdateKind::ChatMember, ChatMember),
} }

View file

@ -1,7 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
dispatching::dialogue::{Dialogue, GetChatId, Storage}, dispatching::{
dialogue::{Dialogue, GetChatId, Storage},
AllowedUpdates,
},
types::{Me, Message}, types::{Me, Message},
utils::command::BotCommands, utils::command::BotCommands,
}; };
@ -58,7 +61,7 @@ pub trait HandlerExt<Output> {
F: HandlerFactory<Out = Output>; F: HandlerFactory<Out = Output>;
} }
impl<Output> HandlerExt<Output> for Handler<'static, DependencyMap, Output> impl<Output> HandlerExt<Output> for Handler<'static, DependencyMap, Output, AllowedUpdates>
where where
Output: Send + Sync + 'static, Output: Send + Sync + 'static,
{ {

View file

@ -1,9 +1,11 @@
use dptree::{di::DependencyMap, Handler}; use dptree::{di::DependencyMap, Handler};
use crate::dispatching::AllowedUpdates;
/// Something that can construct a handler. /// Something that can construct a handler.
#[deprecated(note = "Use the teloxide::handler! API")] #[deprecated(note = "Use the teloxide::handler! API")]
pub trait HandlerFactory { pub trait HandlerFactory {
type Out; type Out;
fn handler() -> Handler<'static, DependencyMap, Self::Out>; fn handler() -> Handler<'static, DependencyMap, Self::Out, AllowedUpdates>;
} }

View file

@ -98,6 +98,7 @@
#[cfg(all(feature = "ctrlc_handler"))] #[cfg(all(feature = "ctrlc_handler"))]
pub mod repls; pub mod repls;
mod allowed_updates;
pub mod dialogue; pub mod dialogue;
mod dispatcher; mod dispatcher;
mod distribution; mod distribution;
@ -108,6 +109,7 @@ pub mod stop_token;
pub mod update_listeners; pub mod update_listeners;
pub use crate::utils::shutdown_token::{IdleShutdownError, ShutdownToken}; pub use crate::utils::shutdown_token::{IdleShutdownError, ShutdownToken};
pub use allowed_updates::AllowedUpdates;
pub use dispatcher::{Dispatcher, DispatcherBuilder, UpdateHandler}; pub use dispatcher::{Dispatcher, DispatcherBuilder, UpdateHandler};
pub use distribution::DefaultKey; pub use distribution::DefaultKey;
pub use filter_ext::{MessageFilterExt, UpdateFilterExt}; pub use filter_ext::{MessageFilterExt, UpdateFilterExt};

View file

@ -83,7 +83,7 @@ pub async fn commands_repl_with_listener<'a, R, Cmd, H, L, ListenerE, E, Args>(
Dispatcher::builder( Dispatcher::builder(
bot, bot,
Update::filter_message().filter_command::<Cmd>().branch(dptree::endpoint(handler)), Update::filter_message().filter_command::<Cmd>().chain(dptree::endpoint(handler)),
) )
.default_handler(ignore_update) .default_handler(ignore_update)
.build() .build()

View file

@ -59,7 +59,7 @@ where
// messages. See <https://github.com/teloxide/teloxide/issues/557>. // messages. See <https://github.com/teloxide/teloxide/issues/557>.
let ignore_update = |_upd| Box::pin(async {}); let ignore_update = |_upd| Box::pin(async {});
Dispatcher::builder(bot, Update::filter_message().branch(dptree::endpoint(handler))) Dispatcher::builder(bot, Update::filter_message().chain(dptree::endpoint(handler)))
.default_handler(ignore_update) .default_handler(ignore_update)
.build() .build()
.setup_ctrlc_handler() .setup_ctrlc_handler()

View file

@ -1,62 +1,62 @@
#[cfg(feature = "macros")] // #[cfg(feature = "macros")]
use teloxide::macros::DialogueState; // use teloxide::macros::DialogueState;
// We put tests here because macro expand in unit tests in the crate was a // // We put tests here because macro expand in unit tests in the crate was a
// failure // // failure
#[test] // #[test]
#[cfg(feature = "macros")] // #[cfg(feature = "macros")]
fn compile_test() { // fn compile_test() {
#[allow(dead_code)] // #[allow(dead_code)]
#[derive(DialogueState, Clone)] // #[derive(DialogueState, Clone)]
#[handler_out(Result<(), teloxide::RequestError>)] // #[handler_out(Result<(), teloxide::RequestError>)]
enum State { // enum State {
#[handler(handle_start)] // #[handler(handle_start)]
Start, // Start,
#[handler(handle_have_data)] // #[handler(handle_have_data)]
HaveData(String), // HaveData(String),
} // }
impl Default for State { // impl Default for State {
fn default() -> Self { // fn default() -> Self {
Self::Start // Self::Start
} // }
} // }
async fn handle_start() -> Result<(), teloxide::RequestError> { // async fn handle_start() -> Result<(), teloxide::RequestError> {
Ok(()) // Ok(())
} // }
async fn handle_have_data() -> Result<(), teloxide::RequestError> { // async fn handle_have_data() -> Result<(), teloxide::RequestError> {
Ok(()) // Ok(())
} // }
} // }
#[test] // #[test]
#[cfg(feature = "macros")] // #[cfg(feature = "macros")]
fn compile_test_generics() { // fn compile_test_generics() {
#[allow(dead_code)] // #[allow(dead_code)]
#[derive(DialogueState, Clone)] // #[derive(DialogueState, Clone)]
#[handler_out(Result<(), teloxide::RequestError>)] // #[handler_out(Result<(), teloxide::RequestError>)]
enum State<X: Clone + Send + Sync + 'static> { // enum State<X: Clone + Send + Sync + 'static> {
#[handler(handle_start)] // #[handler(handle_start)]
Start, // Start,
#[handler(handle_have_data)] // #[handler(handle_have_data)]
HaveData(X), // HaveData(X),
} // }
impl<X: Clone + Send + Sync + 'static> Default for State<X> { // impl<X: Clone + Send + Sync + 'static> Default for State<X> {
fn default() -> Self { // fn default() -> Self {
Self::Start // Self::Start
} // }
} // }
async fn handle_start() -> Result<(), teloxide::RequestError> { // async fn handle_start() -> Result<(), teloxide::RequestError> {
Ok(()) // Ok(())
} // }
async fn handle_have_data() -> Result<(), teloxide::RequestError> { // async fn handle_have_data() -> Result<(), teloxide::RequestError> {
Ok(()) // Ok(())
} // }
} // }