diff --git a/CHANGELOG.md b/CHANGELOG.md index 7becfdda..34f2e5d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `BotCommands::descriptions` now returns `CommandDescriptions` instead of `String` [**BC**]. - Mark `Dialogue::new` as `#[must_use]`. +### Fixed + + - Concurrent update handling in the new dispatcher ([issue 536](https://github.com/teloxide/teloxide/issues/536)). + ### Deprecated - `HandlerFactory` and `HandlerExt::dispatch_by` in favour of `teloxide::handler!`. diff --git a/src/dispatching/dispatcher.rs b/src/dispatching/dispatcher.rs index 5c93b918..fd194430 100644 --- a/src/dispatching/dispatcher.rs +++ b/src/dispatching/dispatcher.rs @@ -9,10 +9,16 @@ use crate::{ }; use dptree::di::{DependencyMap, DependencySupplier}; -use futures::{future::BoxFuture, StreamExt}; -use std::{collections::HashSet, fmt::Debug, ops::ControlFlow, sync::Arc}; +use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Debug, + ops::{ControlFlow, Deref}, + sync::Arc, +}; use teloxide_core::{requests::Request, types::UpdateKind}; use tokio::time::timeout; +use tokio_stream::wrappers::ReceiverStream; use std::future::Future; @@ -20,7 +26,7 @@ use std::future::Future; pub struct DispatcherBuilder { bot: R, dependencies: DependencyMap, - handler: UpdateHandler, + handler: Arc>, default_handler: DefaultHandler, error_handler: Arc + Send + Sync>, } @@ -42,7 +48,7 @@ where let handler = Arc::new(handler); Self { - default_handler: Box::new(move |upd| { + default_handler: Arc::new(move |upd| { let handler = Arc::clone(&handler); Box::pin(handler(upd)) }), @@ -77,17 +83,29 @@ where error_handler: self.error_handler, allowed_updates: Default::default(), state: ShutdownToken::new(), + workers: HashMap::new(), + default_worker: None, } } } /// The base for update dispatching. 
+/// +/// Updates from different chats are handled concurrently, whereas updates from +/// the same chat are handled sequentially. If the dispatcher is unable to +/// determine a chat ID of an incoming update, it will be handled concurrently. pub struct Dispatcher { bot: R, dependencies: DependencyMap, - handler: UpdateHandler, + handler: Arc>, default_handler: DefaultHandler, + + // Tokio TX channel parts associated with chat IDs that consume updates sequentially. + workers: HashMap, + // The default TX part that consumes updates concurrently. + default_worker: Option, + error_handler: Arc + Send + Sync>, // TODO: respect allowed_udpates allowed_updates: HashSet, @@ -95,13 +113,18 @@ pub struct Dispatcher { state: ShutdownToken, } +struct Worker { + tx: tokio::sync::mpsc::Sender, + handle: tokio::task::JoinHandle<()>, +} + // TODO: it is allowed to return message as response on telegram request in // webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates /// A handler that processes updates from Telegram. pub type UpdateHandler = dptree::Handler<'static, DependencyMap, Result<(), Err>>; -type DefaultHandler = Box) -> BoxFuture<'static, ()> + Send + Sync>; +type DefaultHandler = Arc) -> BoxFuture<'static, ()> + Send + Sync>; impl Dispatcher where @@ -117,8 +140,8 @@ where DispatcherBuilder { bot, dependencies: DependencyMap::new(), - handler, - default_handler: Box::new(|upd| { + handler: Arc::new(handler), + default_handler: Arc::new(|upd| { log::warn!("Unhandled update: {:?}", upd); Box::pin(async {}) }), @@ -205,13 +228,21 @@ where } } - // TODO: wait for executing handlers? 
+ self.workers + .drain() + .map(|(_chat_id, worker)| worker.handle) + .chain(self.default_worker.take().map(|worker| worker.handle)) + .collect::>() + .for_each(|res| async { + res.expect("Failed to wait for a worker."); + }) + .await; self.state.done(); } async fn process_update( - &self, + &mut self, update: Result, err_handler: &Arc, ) where @@ -229,19 +260,21 @@ where return; } - let mut deps = self.dependencies.clone(); - deps.insert(upd); + let deps = self.dependencies.clone(); + let handler = Arc::clone(&self.handler); + let default_handler = Arc::clone(&self.default_handler); + let error_handler = Arc::clone(&self.error_handler); - match self.handler.dispatch(deps).await { - ControlFlow::Break(Ok(())) => {} - ControlFlow::Break(Err(err)) => { - self.error_handler.clone().handle_error(err).await - } - ControlFlow::Continue(deps) => { - let upd = deps.get(); - (self.default_handler)(upd).await; - } - } + let worker = match upd.chat() { + Some(chat) => self.workers.entry(chat.id).or_insert_with(|| { + spawn_worker(deps, handler, default_handler, error_handler) + }), + None => self.default_worker.get_or_insert_with(|| { + spawn_default_worker(deps, handler, default_handler, error_handler) + }), + }; + + worker.tx.send(upd).await.expect("TX is dead"); } Err(err) => err_handler.clone().handle_error(err).await, } @@ -280,6 +313,80 @@ where } } +const WORKER_QUEUE_SIZE: usize = 64; + +fn spawn_worker( + deps: DependencyMap, + handler: Arc>, + default_handler: DefaultHandler, + error_handler: Arc + Send + Sync>, +) -> Worker +where + Err: Send + Sync + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE); + + let deps = Arc::new(deps); + + let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| { + let deps = Arc::clone(&deps); + let handler = Arc::clone(&handler); + let default_handler = Arc::clone(&default_handler); + let error_handler = Arc::clone(&error_handler); + + handle_update(update, deps, handler, default_handler, 
error_handler) + })); + + Worker { tx, handle } +} + +fn spawn_default_worker( + deps: DependencyMap, + handler: Arc>, + default_handler: DefaultHandler, + error_handler: Arc + Send + Sync>, +) -> Worker +where + Err: Send + Sync + 'static, +{ + let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE); + + let deps = Arc::new(deps); + + let handle = tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| { + let deps = Arc::clone(&deps); + let handler = Arc::clone(&handler); + let default_handler = Arc::clone(&default_handler); + let error_handler = Arc::clone(&error_handler); + + handle_update(update, deps, handler, default_handler, error_handler) + })); + + Worker { tx, handle } +} + +async fn handle_update( + update: Update, + deps: Arc, + handler: Arc>, + default_handler: DefaultHandler, + error_handler: Arc + Send + Sync>, +) where + Err: Send + Sync + 'static, +{ + let mut deps = deps.deref().clone(); + deps.insert(update); + + match handler.dispatch(deps).await { + ControlFlow::Break(Ok(())) => {} + ControlFlow::Break(Err(err)) => error_handler.clone().handle_error(err).await, + ControlFlow::Continue(deps) => { + let update = deps.get(); + (default_handler)(update).await; + } + } +} + #[cfg(test)] mod tests { use std::convert::Infallible;