mirror of
https://github.com/teloxide/teloxide.git
synced 2025-01-18 15:20:15 +01:00
Merge pull request #585 from teloxide/concurrent-dispatching
Implement concurrent update dispatching
This commit is contained in:
commit
ba5dc486ce
2 changed files with 133 additions and 22 deletions
|
@ -27,6 +27,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- `BotCommands::descriptions` now returns `CommandDescriptions` instead of `String` [**BC**].
|
- `BotCommands::descriptions` now returns `CommandDescriptions` instead of `String` [**BC**].
|
||||||
- Mark `Dialogue::new` as `#[must_use]`.
|
- Mark `Dialogue::new` as `#[must_use]`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Concurrent update handling in the new dispatcher ([issue 536](https://github.com/teloxide/teloxide/issues/536)).
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
- `HandlerFactory` and `HandlerExt::dispatch_by` in favour of `teloxide::handler!`.
|
- `HandlerFactory` and `HandlerExt::dispatch_by` in favour of `teloxide::handler!`.
|
||||||
|
|
|
@ -9,10 +9,16 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use dptree::di::{DependencyMap, DependencySupplier};
|
use dptree::di::{DependencyMap, DependencySupplier};
|
||||||
use futures::{future::BoxFuture, StreamExt};
|
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
|
||||||
use std::{collections::HashSet, fmt::Debug, ops::ControlFlow, sync::Arc};
|
use std::{
|
||||||
|
collections::{HashMap, HashSet},
|
||||||
|
fmt::Debug,
|
||||||
|
ops::{ControlFlow, Deref},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
use teloxide_core::{requests::Request, types::UpdateKind};
|
use teloxide_core::{requests::Request, types::UpdateKind};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
|
||||||
|
@ -20,7 +26,7 @@ use std::future::Future;
|
||||||
pub struct DispatcherBuilder<R, Err> {
|
pub struct DispatcherBuilder<R, Err> {
|
||||||
bot: R,
|
bot: R,
|
||||||
dependencies: DependencyMap,
|
dependencies: DependencyMap,
|
||||||
handler: UpdateHandler<Err>,
|
handler: Arc<UpdateHandler<Err>>,
|
||||||
default_handler: DefaultHandler,
|
default_handler: DefaultHandler,
|
||||||
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
||||||
}
|
}
|
||||||
|
@ -42,7 +48,7 @@ where
|
||||||
let handler = Arc::new(handler);
|
let handler = Arc::new(handler);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
default_handler: Box::new(move |upd| {
|
default_handler: Arc::new(move |upd| {
|
||||||
let handler = Arc::clone(&handler);
|
let handler = Arc::clone(&handler);
|
||||||
Box::pin(handler(upd))
|
Box::pin(handler(upd))
|
||||||
}),
|
}),
|
||||||
|
@ -77,17 +83,29 @@ where
|
||||||
error_handler: self.error_handler,
|
error_handler: self.error_handler,
|
||||||
allowed_updates: Default::default(),
|
allowed_updates: Default::default(),
|
||||||
state: ShutdownToken::new(),
|
state: ShutdownToken::new(),
|
||||||
|
workers: HashMap::new(),
|
||||||
|
default_worker: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The base for update dispatching.
|
/// The base for update dispatching.
|
||||||
|
///
|
||||||
|
/// Updates from different chats are handles concurrently, whereas updates from
|
||||||
|
/// the same chats are handled sequentially. If the dispatcher is unable to
|
||||||
|
/// determine a chat ID of an incoming update, it will be handled concurrently.
|
||||||
pub struct Dispatcher<R, Err> {
|
pub struct Dispatcher<R, Err> {
|
||||||
bot: R,
|
bot: R,
|
||||||
dependencies: DependencyMap,
|
dependencies: DependencyMap,
|
||||||
|
|
||||||
handler: UpdateHandler<Err>,
|
handler: Arc<UpdateHandler<Err>>,
|
||||||
default_handler: DefaultHandler,
|
default_handler: DefaultHandler,
|
||||||
|
|
||||||
|
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
|
||||||
|
workers: HashMap<i64, Worker>,
|
||||||
|
// The default TX part that consume updates concurrently.
|
||||||
|
default_worker: Option<Worker>,
|
||||||
|
|
||||||
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
||||||
// TODO: respect allowed_udpates
|
// TODO: respect allowed_udpates
|
||||||
allowed_updates: HashSet<AllowedUpdate>,
|
allowed_updates: HashSet<AllowedUpdate>,
|
||||||
|
@ -95,13 +113,18 @@ pub struct Dispatcher<R, Err> {
|
||||||
state: ShutdownToken,
|
state: ShutdownToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Worker {
|
||||||
|
tx: tokio::sync::mpsc::Sender<Update>,
|
||||||
|
handle: tokio::task::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: it is allowed to return message as response on telegram request in
|
// TODO: it is allowed to return message as response on telegram request in
|
||||||
// webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates
|
// webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates
|
||||||
|
|
||||||
/// A handler that processes updates from Telegram.
|
/// A handler that processes updates from Telegram.
|
||||||
pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(), Err>>;
|
pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(), Err>>;
|
||||||
|
|
||||||
type DefaultHandler = Box<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
|
type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
|
||||||
|
|
||||||
impl<R, Err> Dispatcher<R, Err>
|
impl<R, Err> Dispatcher<R, Err>
|
||||||
where
|
where
|
||||||
|
@ -117,8 +140,8 @@ where
|
||||||
DispatcherBuilder {
|
DispatcherBuilder {
|
||||||
bot,
|
bot,
|
||||||
dependencies: DependencyMap::new(),
|
dependencies: DependencyMap::new(),
|
||||||
handler,
|
handler: Arc::new(handler),
|
||||||
default_handler: Box::new(|upd| {
|
default_handler: Arc::new(|upd| {
|
||||||
log::warn!("Unhandled update: {:?}", upd);
|
log::warn!("Unhandled update: {:?}", upd);
|
||||||
Box::pin(async {})
|
Box::pin(async {})
|
||||||
}),
|
}),
|
||||||
|
@ -205,13 +228,21 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: wait for executing handlers?
|
self.workers
|
||||||
|
.drain()
|
||||||
|
.map(|(_chat_id, worker)| worker.handle)
|
||||||
|
.chain(self.default_worker.take().map(|worker| worker.handle))
|
||||||
|
.collect::<FuturesUnordered<_>>()
|
||||||
|
.for_each(|res| async {
|
||||||
|
res.expect("Failed to wait for a worker.");
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
self.state.done();
|
self.state.done();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_update<LErr, LErrHandler>(
|
async fn process_update<LErr, LErrHandler>(
|
||||||
&self,
|
&mut self,
|
||||||
update: Result<Update, LErr>,
|
update: Result<Update, LErr>,
|
||||||
err_handler: &Arc<LErrHandler>,
|
err_handler: &Arc<LErrHandler>,
|
||||||
) where
|
) where
|
||||||
|
@ -229,19 +260,21 @@ where
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut deps = self.dependencies.clone();
|
let deps = self.dependencies.clone();
|
||||||
deps.insert(upd);
|
let handler = Arc::clone(&self.handler);
|
||||||
|
let default_handler = Arc::clone(&self.default_handler);
|
||||||
|
let error_handler = Arc::clone(&self.error_handler);
|
||||||
|
|
||||||
match self.handler.dispatch(deps).await {
|
let worker = match upd.chat() {
|
||||||
ControlFlow::Break(Ok(())) => {}
|
Some(chat) => self.workers.entry(chat.id).or_insert_with(|| {
|
||||||
ControlFlow::Break(Err(err)) => {
|
spawn_worker(deps, handler, default_handler, error_handler)
|
||||||
self.error_handler.clone().handle_error(err).await
|
}),
|
||||||
}
|
None => self.default_worker.get_or_insert_with(|| {
|
||||||
ControlFlow::Continue(deps) => {
|
spawn_default_worker(deps, handler, default_handler, error_handler)
|
||||||
let upd = deps.get();
|
}),
|
||||||
(self.default_handler)(upd).await;
|
};
|
||||||
}
|
|
||||||
}
|
worker.tx.send(upd).await.expect("TX is dead");
|
||||||
}
|
}
|
||||||
Err(err) => err_handler.clone().handle_error(err).await,
|
Err(err) => err_handler.clone().handle_error(err).await,
|
||||||
}
|
}
|
||||||
|
@ -280,6 +313,80 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const WORKER_QUEUE_SIZE: usize = 64;
|
||||||
|
|
||||||
|
fn spawn_worker<Err>(
|
||||||
|
deps: DependencyMap,
|
||||||
|
handler: Arc<UpdateHandler<Err>>,
|
||||||
|
default_handler: DefaultHandler,
|
||||||
|
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
||||||
|
) -> Worker
|
||||||
|
where
|
||||||
|
Err: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);
|
||||||
|
|
||||||
|
let deps = Arc::new(deps);
|
||||||
|
|
||||||
|
let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
|
||||||
|
let deps = Arc::clone(&deps);
|
||||||
|
let handler = Arc::clone(&handler);
|
||||||
|
let default_handler = Arc::clone(&default_handler);
|
||||||
|
let error_handler = Arc::clone(&error_handler);
|
||||||
|
|
||||||
|
handle_update(update, deps, handler, default_handler, error_handler)
|
||||||
|
}));
|
||||||
|
|
||||||
|
Worker { tx, handle }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_default_worker<Err>(
|
||||||
|
deps: DependencyMap,
|
||||||
|
handler: Arc<UpdateHandler<Err>>,
|
||||||
|
default_handler: DefaultHandler,
|
||||||
|
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
||||||
|
) -> Worker
|
||||||
|
where
|
||||||
|
Err: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);
|
||||||
|
|
||||||
|
let deps = Arc::new(deps);
|
||||||
|
|
||||||
|
let handle = tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| {
|
||||||
|
let deps = Arc::clone(&deps);
|
||||||
|
let handler = Arc::clone(&handler);
|
||||||
|
let default_handler = Arc::clone(&default_handler);
|
||||||
|
let error_handler = Arc::clone(&error_handler);
|
||||||
|
|
||||||
|
handle_update(update, deps, handler, default_handler, error_handler)
|
||||||
|
}));
|
||||||
|
|
||||||
|
Worker { tx, handle }
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_update<Err>(
|
||||||
|
update: Update,
|
||||||
|
deps: Arc<DependencyMap>,
|
||||||
|
handler: Arc<UpdateHandler<Err>>,
|
||||||
|
default_handler: DefaultHandler,
|
||||||
|
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
|
||||||
|
) where
|
||||||
|
Err: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let mut deps = deps.deref().clone();
|
||||||
|
deps.insert(update);
|
||||||
|
|
||||||
|
match handler.dispatch(deps).await {
|
||||||
|
ControlFlow::Break(Ok(())) => {}
|
||||||
|
ControlFlow::Break(Err(err)) => error_handler.clone().handle_error(err).await,
|
||||||
|
ControlFlow::Continue(deps) => {
|
||||||
|
let update = deps.get();
|
||||||
|
(default_handler)(update).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
|
|
Loading…
Reference in a new issue