diff --git a/Cargo.toml b/Cargo.toml index 013f97bd..3ff6a2d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "teloxide" version = "0.7.2" -edition = "2018" +edition = "2021" description = "An elegant Telegram bots framework for Rust" repository = "https://github.com/teloxide/teloxide" documentation = "https://docs.rs/teloxide/" diff --git a/src/dispatching/dispatcher.rs b/src/dispatching/dispatcher.rs index c6f14616..31494664 100644 --- a/src/dispatching/dispatcher.rs +++ b/src/dispatching/dispatcher.rs @@ -1,10 +1,11 @@ use crate::{ dispatching::{ - stop_token::StopToken, update_listeners, update_listeners::UpdateListener, ShutdownToken, + distribution::default_distribution_function, stop_token::StopToken, update_listeners, + update_listeners::UpdateListener, DefaultKey, ShutdownToken, }, error_handlers::{ErrorHandler, LoggingErrorHandler}, - requests::Requester, - types::{AllowedUpdate, Update}, + requests::{Request, Requester}, + types::{AllowedUpdate, Update, UpdateKind}, utils::shutdown_token::shutdown_check_timeout_for, }; @@ -13,28 +14,27 @@ use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use std::{ collections::{HashMap, HashSet}, fmt::Debug, + hash::Hash, ops::{ControlFlow, Deref}, sync::Arc, }; -use teloxide_core::{ - requests::Request, - types::{ChatId, UpdateKind}, -}; use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; use std::future::Future; /// The builder for [`Dispatcher`]. -pub struct DispatcherBuilder { +pub struct DispatcherBuilder { bot: R, dependencies: DependencyMap, handler: Arc>, default_handler: DefaultHandler, error_handler: Arc + Send + Sync>, + distribution_f: fn(&Update) -> Option, + worker_queue_size: usize, } -impl DispatcherBuilder +impl DispatcherBuilder where R: Clone + Requester + Clone + Send + Sync + 'static, Err: Debug + Send + Sync + 'static, @@ -75,17 +75,67 @@ where Self { dependencies, ..self } } + /// Specifies size of the queue for workers. + /// + /// By default it's 64. + #[must_use] + pub fn worker_queue_size(self, size: usize) -> Self { + Self { worker_queue_size: size, ..self } + } + + /// Specifies the distribution function that decides how updates are grouped + /// before execution. + pub fn distribution_function( + self, + f: fn(&Update) -> Option, + ) -> DispatcherBuilder + where + K: Hash + Eq, + { + let Self { + bot, + dependencies, + handler, + default_handler, + error_handler, + distribution_f: _, + worker_queue_size, + } = self; + + DispatcherBuilder { + bot, + dependencies, + handler, + default_handler, + error_handler, + distribution_f: f, + worker_queue_size, + } + } + /// Constructs [`Dispatcher`]. #[must_use] - pub fn build(self) -> Dispatcher { + pub fn build(self) -> Dispatcher { + let Self { + bot, + dependencies, + handler, + default_handler, + error_handler, + distribution_f, + worker_queue_size, + } = self; + Dispatcher { - bot: self.bot.clone(), - dependencies: self.dependencies, - handler: self.handler, - default_handler: self.default_handler, - error_handler: self.error_handler, + bot, + dependencies, + handler, + default_handler, + error_handler, allowed_updates: Default::default(), state: ShutdownToken::new(), + distribution_f, + worker_queue_size, workers: HashMap::new(), default_worker: None, } @@ -97,15 +147,20 @@ where /// Updates from different chats are handles concurrently, whereas updates from /// the same chats are handled sequentially. If the dispatcher is unable to /// determine a chat ID of an incoming update, it will be handled concurrently. -pub struct Dispatcher { +/// Note that this behaviour can be altered with [`distribution_function`]. +/// +/// [`distribution_function`]: DispatcherBuilder::distribution_function +pub struct Dispatcher { bot: R, dependencies: DependencyMap, handler: Arc>, default_handler: DefaultHandler, + distribution_f: fn(&Update) -> Option, + worker_queue_size: usize, // Tokio TX channel parts associated with chat IDs that consume updates sequentially. - workers: HashMap, + workers: HashMap, // The default TX part that consume updates concurrently. default_worker: Option, @@ -129,17 +184,19 @@ pub type UpdateHandler = dptree::Handler<'static, DependencyMap, Result<(), type DefaultHandler = Arc) -> BoxFuture<'static, ()> + Send + Sync>; -impl Dispatcher +impl Dispatcher where R: Requester + Clone + Send + Sync + 'static, Err: Send + Sync + 'static, { /// Constructs a new [`DispatcherBuilder`] with `bot` and `handler`. #[must_use] - pub fn builder(bot: R, handler: UpdateHandler) -> DispatcherBuilder + pub fn builder(bot: R, handler: UpdateHandler) -> DispatcherBuilder where Err: Debug, { + const DEFAULT_WORKER_QUEUE_SIZE: usize = 64; + DispatcherBuilder { bot, dependencies: DependencyMap::new(), @@ -149,9 +206,18 @@ where Box::pin(async {}) }), error_handler: LoggingErrorHandler::new(), + worker_queue_size: DEFAULT_WORKER_QUEUE_SIZE, + distribution_f: default_distribution_function, } } +} +impl Dispatcher +where + R: Requester + Clone + Send + Sync + 'static, + Err: Send + Sync + 'static, + Key: Hash + Eq, +{ /// Starts your bot with the default parameters. /// /// The default parameters are a long polling update listener and log all @@ -263,17 +329,34 @@ where return; } - let deps = self.dependencies.clone(); - let handler = Arc::clone(&self.handler); - let default_handler = Arc::clone(&self.default_handler); - let error_handler = Arc::clone(&self.error_handler); + let worker = match (self.distribution_f)(&upd) { + Some(key) => self.workers.entry(key).or_insert_with(|| { + let deps = self.dependencies.clone(); + let handler = Arc::clone(&self.handler); + let default_handler = Arc::clone(&self.default_handler); + let error_handler = Arc::clone(&self.error_handler); - let worker = match upd.chat() { - Some(chat) => self.workers.entry(chat.id).or_insert_with(|| { - spawn_worker(deps, handler, default_handler, error_handler) + spawn_worker( + deps, + handler, + default_handler, + error_handler, + self.worker_queue_size, + ) }), None => self.default_worker.get_or_insert_with(|| { - spawn_default_worker(deps, handler, default_handler, error_handler) + let deps = self.dependencies.clone(); + let handler = Arc::clone(&self.handler); + let default_handler = Arc::clone(&self.default_handler); + let error_handler = Arc::clone(&self.error_handler); + + spawn_default_worker( + deps, + handler, + default_handler, + error_handler, + self.worker_queue_size, + ) }), }; @@ -316,18 +399,17 @@ where } } -const WORKER_QUEUE_SIZE: usize = 64; - fn spawn_worker( deps: DependencyMap, handler: Arc>, default_handler: DefaultHandler, error_handler: Arc + Send + Sync>, + queue_size: usize, ) -> Worker where Err: Send + Sync + 'static, { - let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE); + let (tx, rx) = tokio::sync::mpsc::channel(queue_size); let deps = Arc::new(deps); @@ -348,11 +430,12 @@ fn spawn_default_worker( handler: Arc>, default_handler: DefaultHandler, error_handler: Arc + Send + Sync>, + queue_size: usize, ) -> Worker where Err: Send + Sync + 'static, { - let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE); + let (tx, rx) = tokio::sync::mpsc::channel(queue_size); let deps = Arc::new(deps); @@ -403,7 +486,7 @@ mod tests { tokio::spawn(async { // Just check that this code compiles. if false { - Dispatcher::<_, Infallible>::builder(Bot::new(""), dptree::entry()) + Dispatcher::<_, Infallible, _>::builder(Bot::new(""), dptree::entry()) .build() .dispatch() .await; diff --git a/src/dispatching/distribution.rs b/src/dispatching/distribution.rs new file mode 100644 index 00000000..208e0018 --- /dev/null +++ b/src/dispatching/distribution.rs @@ -0,0 +1,9 @@ +use teloxide_core::types::{ChatId, Update}; + +/// Default distribution key for dispatching. +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct DefaultKey(ChatId); + +pub(crate) fn default_distribution_function(update: &Update) -> Option { + update.chat().map(|c| c.id).map(DefaultKey) +} diff --git a/src/dispatching/mod.rs b/src/dispatching/mod.rs index 33e02393..3599c436 100644 --- a/src/dispatching/mod.rs +++ b/src/dispatching/mod.rs @@ -100,6 +100,7 @@ pub mod repls; pub mod dialogue; mod dispatcher; +mod distribution; mod filter_ext; mod handler_ext; mod handler_factory; @@ -108,6 +109,7 @@ pub mod update_listeners; pub use crate::utils::shutdown_token::{IdleShutdownError, ShutdownToken}; pub use dispatcher::{Dispatcher, DispatcherBuilder, UpdateHandler}; +pub use distribution::DefaultKey; pub use filter_ext::{MessageFilterExt, UpdateFilterExt}; pub use handler_ext::HandlerExt; #[allow(deprecated)]