Merge pull request #585 from teloxide/concurrent-dispatching

Implement concurrent update dispatching
This commit is contained in:
Hirrolot 2022-04-13 17:38:50 +06:00 committed by GitHub
commit ba5dc486ce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 22 deletions

View file

@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BotCommands::descriptions` now returns `CommandDescriptions` instead of `String` [**BC**]. - `BotCommands::descriptions` now returns `CommandDescriptions` instead of `String` [**BC**].
- Mark `Dialogue::new` as `#[must_use]`. - Mark `Dialogue::new` as `#[must_use]`.
### Fixed
- Concurrent update handling in the new dispatcher ([issue 536](https://github.com/teloxide/teloxide/issues/536)).
### Deprecated ### Deprecated
- `HandlerFactory` and `HandlerExt::dispatch_by` in favour of `teloxide::handler!`. - `HandlerFactory` and `HandlerExt::dispatch_by` in favour of `teloxide::handler!`.

View file

@ -9,10 +9,16 @@ use crate::{
}; };
use dptree::di::{DependencyMap, DependencySupplier}; use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, StreamExt}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use std::{collections::HashSet, fmt::Debug, ops::ControlFlow, sync::Arc}; use std::{
collections::{HashMap, HashSet},
fmt::Debug,
ops::{ControlFlow, Deref},
sync::Arc,
};
use teloxide_core::{requests::Request, types::UpdateKind}; use teloxide_core::{requests::Request, types::UpdateKind};
use tokio::time::timeout; use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use std::future::Future; use std::future::Future;
@ -20,7 +26,7 @@ use std::future::Future;
pub struct DispatcherBuilder<R, Err> { pub struct DispatcherBuilder<R, Err> {
bot: R, bot: R,
dependencies: DependencyMap, dependencies: DependencyMap,
handler: UpdateHandler<Err>, handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler, default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>, error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
} }
@ -42,7 +48,7 @@ where
let handler = Arc::new(handler); let handler = Arc::new(handler);
Self { Self {
default_handler: Box::new(move |upd| { default_handler: Arc::new(move |upd| {
let handler = Arc::clone(&handler); let handler = Arc::clone(&handler);
Box::pin(handler(upd)) Box::pin(handler(upd))
}), }),
@ -77,17 +83,29 @@ where
error_handler: self.error_handler, error_handler: self.error_handler,
allowed_updates: Default::default(), allowed_updates: Default::default(),
state: ShutdownToken::new(), state: ShutdownToken::new(),
workers: HashMap::new(),
default_worker: None,
} }
} }
} }
/// The base for update dispatching. /// The base for update dispatching.
///
/// Updates from different chats are handled concurrently, whereas updates from
/// the same chat are handled sequentially. If the dispatcher is unable to
/// determine the chat ID of an incoming update, the update will be handled concurrently.
pub struct Dispatcher<R, Err> { pub struct Dispatcher<R, Err> {
bot: R, bot: R,
dependencies: DependencyMap, dependencies: DependencyMap,
handler: UpdateHandler<Err>, handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler, default_handler: DefaultHandler,
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
workers: HashMap<i64, Worker>,
// The default TX part that consumes updates concurrently.
default_worker: Option<Worker>,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>, error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
// TODO: respect allowed_updates // TODO: respect allowed_updates
allowed_updates: HashSet<AllowedUpdate>, allowed_updates: HashSet<AllowedUpdate>,
@ -95,13 +113,18 @@ pub struct Dispatcher<R, Err> {
state: ShutdownToken, state: ShutdownToken,
} }
/// A spawned update-processing task paired with the channel that feeds it.
///
/// The dispatcher keeps one `Worker` per chat ID plus an optional default one
/// for updates whose chat cannot be determined.
struct Worker {
// Sending half of the bounded channel through which updates reach the task.
tx: tokio::sync::mpsc::Sender<Update>,
// Join handle of the spawned task; the dispatcher awaits it before marking
// its state as done.
handle: tokio::task::JoinHandle<()>,
}
// TODO: it is allowed to return message as response on telegram request in // TODO: it is allowed to return message as response on telegram request in
// webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates // webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates
/// A handler that processes updates from Telegram. /// A handler that processes updates from Telegram.
pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(), Err>>; pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(), Err>>;
type DefaultHandler = Box<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>; type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
impl<R, Err> Dispatcher<R, Err> impl<R, Err> Dispatcher<R, Err>
where where
@ -117,8 +140,8 @@ where
DispatcherBuilder { DispatcherBuilder {
bot, bot,
dependencies: DependencyMap::new(), dependencies: DependencyMap::new(),
handler, handler: Arc::new(handler),
default_handler: Box::new(|upd| { default_handler: Arc::new(|upd| {
log::warn!("Unhandled update: {:?}", upd); log::warn!("Unhandled update: {:?}", upd);
Box::pin(async {}) Box::pin(async {})
}), }),
@ -205,13 +228,21 @@ where
} }
} }
// TODO: wait for executing handlers? self.workers
.drain()
.map(|(_chat_id, worker)| worker.handle)
.chain(self.default_worker.take().map(|worker| worker.handle))
.collect::<FuturesUnordered<_>>()
.for_each(|res| async {
res.expect("Failed to wait for a worker.");
})
.await;
self.state.done(); self.state.done();
} }
async fn process_update<LErr, LErrHandler>( async fn process_update<LErr, LErrHandler>(
&self, &mut self,
update: Result<Update, LErr>, update: Result<Update, LErr>,
err_handler: &Arc<LErrHandler>, err_handler: &Arc<LErrHandler>,
) where ) where
@ -229,19 +260,21 @@ where
return; return;
} }
let mut deps = self.dependencies.clone(); let deps = self.dependencies.clone();
deps.insert(upd); let handler = Arc::clone(&self.handler);
let default_handler = Arc::clone(&self.default_handler);
let error_handler = Arc::clone(&self.error_handler);
match self.handler.dispatch(deps).await { let worker = match upd.chat() {
ControlFlow::Break(Ok(())) => {} Some(chat) => self.workers.entry(chat.id).or_insert_with(|| {
ControlFlow::Break(Err(err)) => { spawn_worker(deps, handler, default_handler, error_handler)
self.error_handler.clone().handle_error(err).await }),
} None => self.default_worker.get_or_insert_with(|| {
ControlFlow::Continue(deps) => { spawn_default_worker(deps, handler, default_handler, error_handler)
let upd = deps.get(); }),
(self.default_handler)(upd).await; };
}
} worker.tx.send(upd).await.expect("TX is dead");
} }
Err(err) => err_handler.clone().handle_error(err).await, Err(err) => err_handler.clone().handle_error(err).await,
} }
@ -280,6 +313,80 @@ where
} }
} }
/// Capacity of the bounded `tokio::sync::mpsc` channel created for every worker.
// NOTE(review): presumably a full queue makes `tx.send(..).await` wait until the
// worker catches up — confirm against the tokio mpsc documentation.
const WORKER_QUEUE_SIZE: usize = 64;
/// Spawns a task that handles the updates sent to it one after another.
///
/// A bounded channel of [`WORKER_QUEUE_SIZE`] elements is created; its receiving
/// half is turned into a stream that is driven with `for_each`, so the future
/// produced for one update is awaited before the next update is taken.
///
/// Returns a [`Worker`] holding the sending half of the channel and the join
/// handle of the spawned task.
fn spawn_worker<Err>(
    deps: DependencyMap,
    handler: Arc<UpdateHandler<Err>>,
    default_handler: DefaultHandler,
    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
) -> Worker
where
    Err: Send + Sync + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);

    // The dependency map is shared between invocations; every update gets its
    // own copy inside `handle_update`.
    let shared_deps = Arc::new(deps);

    let updates = ReceiverStream::new(rx);
    let task = updates.for_each(move |update| {
        handle_update(
            update,
            Arc::clone(&shared_deps),
            Arc::clone(&handler),
            Arc::clone(&default_handler),
            Arc::clone(&error_handler),
        )
    });

    Worker { tx, handle: tokio::spawn(task) }
}
/// Spawns a task that handles the updates sent to it concurrently.
///
/// A bounded channel of [`WORKER_QUEUE_SIZE`] elements is created; its receiving
/// half is turned into a stream that is driven with
/// `for_each_concurrent(None, ..)`, i.e. without a limit on the number of
/// updates being handled at the same time.
///
/// Returns a [`Worker`] holding the sending half of the channel and the join
/// handle of the spawned task.
fn spawn_default_worker<Err>(
    deps: DependencyMap,
    handler: Arc<UpdateHandler<Err>>,
    default_handler: DefaultHandler,
    error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
) -> Worker
where
    Err: Send + Sync + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);

    // The dependency map is shared between invocations; every update gets its
    // own copy inside `handle_update`.
    let shared_deps = Arc::new(deps);

    let updates = ReceiverStream::new(rx);
    let task = updates.for_each_concurrent(None, move |update| {
        handle_update(
            update,
            Arc::clone(&shared_deps),
            Arc::clone(&handler),
            Arc::clone(&default_handler),
            Arc::clone(&error_handler),
        )
    });

    Worker { tx, handle: tokio::spawn(task) }
}
async fn handle_update<Err>(
update: Update,
deps: Arc<DependencyMap>,
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
) where
Err: Send + Sync + 'static,
{
let mut deps = deps.deref().clone();
deps.insert(update);
match handler.dispatch(deps).await {
ControlFlow::Break(Ok(())) => {}
ControlFlow::Break(Err(err)) => error_handler.clone().handle_error(err).await,
ControlFlow::Continue(deps) => {
let update = deps.get();
(default_handler)(update).await;
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::convert::Infallible; use std::convert::Infallible;