Merge pull request #589 from teloxide/dp_impr

improve Dispatcher
This commit is contained in:
Waffle Maybe 2022-04-14 19:42:22 +04:00 committed by GitHub
commit 03521bfd3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 127 additions and 33 deletions

View file

@ -1,7 +1,7 @@
[package]
name = "teloxide"
version = "0.7.2"
edition = "2018"
edition = "2021"
description = "An elegant Telegram bots framework for Rust"
repository = "https://github.com/teloxide/teloxide"
documentation = "https://docs.rs/teloxide/"

View file

@ -1,10 +1,11 @@
use crate::{
dispatching::{
stop_token::StopToken, update_listeners, update_listeners::UpdateListener, ShutdownToken,
distribution::default_distribution_function, stop_token::StopToken, update_listeners,
update_listeners::UpdateListener, DefaultKey, ShutdownToken,
},
error_handlers::{ErrorHandler, LoggingErrorHandler},
requests::Requester,
types::{AllowedUpdate, Update},
requests::{Request, Requester},
types::{AllowedUpdate, Update, UpdateKind},
utils::shutdown_token::shutdown_check_timeout_for,
};
@ -13,28 +14,27 @@ use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
hash::Hash,
ops::{ControlFlow, Deref},
sync::Arc,
};
use teloxide_core::{
requests::Request,
types::{ChatId, UpdateKind},
};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use std::future::Future;
/// The builder for [`Dispatcher`].
pub struct DispatcherBuilder<R, Err> {
pub struct DispatcherBuilder<R, Err, Key> {
bot: R,
dependencies: DependencyMap,
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
distribution_f: fn(&Update) -> Option<Key>,
worker_queue_size: usize,
}
impl<R, Err> DispatcherBuilder<R, Err>
impl<R, Err, Key> DispatcherBuilder<R, Err, Key>
where
R: Clone + Requester + Clone + Send + Sync + 'static,
Err: Debug + Send + Sync + 'static,
@ -75,17 +75,67 @@ where
Self { dependencies, ..self }
}
/// Specifies size of the queue for workers.
///
/// By default it's 64.
#[must_use]
pub fn worker_queue_size(self, size: usize) -> Self {
Self { worker_queue_size: size, ..self }
}
/// Specifies the distribution function that decides how updates are grouped
/// before execution.
pub fn distribution_function<K>(
self,
f: fn(&Update) -> Option<K>,
) -> DispatcherBuilder<R, Err, K>
where
K: Hash + Eq,
{
let Self {
bot,
dependencies,
handler,
default_handler,
error_handler,
distribution_f: _,
worker_queue_size,
} = self;
DispatcherBuilder {
bot,
dependencies,
handler,
default_handler,
error_handler,
distribution_f: f,
worker_queue_size,
}
}
/// Constructs [`Dispatcher`].
#[must_use]
pub fn build(self) -> Dispatcher<R, Err> {
pub fn build(self) -> Dispatcher<R, Err, Key> {
let Self {
bot,
dependencies,
handler,
default_handler,
error_handler,
distribution_f,
worker_queue_size,
} = self;
Dispatcher {
bot: self.bot.clone(),
dependencies: self.dependencies,
handler: self.handler,
default_handler: self.default_handler,
error_handler: self.error_handler,
bot,
dependencies,
handler,
default_handler,
error_handler,
allowed_updates: Default::default(),
state: ShutdownToken::new(),
distribution_f,
worker_queue_size,
workers: HashMap::new(),
default_worker: None,
}
@ -97,15 +147,20 @@ where
/// Updates from different chats are handles concurrently, whereas updates from
/// the same chats are handled sequentially. If the dispatcher is unable to
/// determine a chat ID of an incoming update, it will be handled concurrently.
pub struct Dispatcher<R, Err> {
/// Note that this behaviour can be altered with [`distribution_function`].
///
/// [`distribution_function`]: DispatcherBuilder::distribution_function
pub struct Dispatcher<R, Err, Key> {
bot: R,
dependencies: DependencyMap,
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
distribution_f: fn(&Update) -> Option<Key>,
worker_queue_size: usize,
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
workers: HashMap<ChatId, Worker>,
workers: HashMap<Key, Worker>,
// The default TX part that consume updates concurrently.
default_worker: Option<Worker>,
@ -129,17 +184,19 @@ pub type UpdateHandler<Err> = dptree::Handler<'static, DependencyMap, Result<(),
type DefaultHandler = Arc<dyn Fn(Arc<Update>) -> BoxFuture<'static, ()> + Send + Sync>;
impl<R, Err> Dispatcher<R, Err>
impl<R, Err> Dispatcher<R, Err, DefaultKey>
where
R: Requester + Clone + Send + Sync + 'static,
Err: Send + Sync + 'static,
{
/// Constructs a new [`DispatcherBuilder`] with `bot` and `handler`.
#[must_use]
pub fn builder(bot: R, handler: UpdateHandler<Err>) -> DispatcherBuilder<R, Err>
pub fn builder(bot: R, handler: UpdateHandler<Err>) -> DispatcherBuilder<R, Err, DefaultKey>
where
Err: Debug,
{
const DEFAULT_WORKER_QUEUE_SIZE: usize = 64;
DispatcherBuilder {
bot,
dependencies: DependencyMap::new(),
@ -149,9 +206,18 @@ where
Box::pin(async {})
}),
error_handler: LoggingErrorHandler::new(),
worker_queue_size: DEFAULT_WORKER_QUEUE_SIZE,
distribution_f: default_distribution_function,
}
}
}
impl<R, Err, Key> Dispatcher<R, Err, Key>
where
R: Requester + Clone + Send + Sync + 'static,
Err: Send + Sync + 'static,
Key: Hash + Eq,
{
/// Starts your bot with the default parameters.
///
/// The default parameters are a long polling update listener and log all
@ -263,17 +329,34 @@ where
return;
}
let deps = self.dependencies.clone();
let handler = Arc::clone(&self.handler);
let default_handler = Arc::clone(&self.default_handler);
let error_handler = Arc::clone(&self.error_handler);
let worker = match (self.distribution_f)(&upd) {
Some(key) => self.workers.entry(key).or_insert_with(|| {
let deps = self.dependencies.clone();
let handler = Arc::clone(&self.handler);
let default_handler = Arc::clone(&self.default_handler);
let error_handler = Arc::clone(&self.error_handler);
let worker = match upd.chat() {
Some(chat) => self.workers.entry(chat.id).or_insert_with(|| {
spawn_worker(deps, handler, default_handler, error_handler)
spawn_worker(
deps,
handler,
default_handler,
error_handler,
self.worker_queue_size,
)
}),
None => self.default_worker.get_or_insert_with(|| {
spawn_default_worker(deps, handler, default_handler, error_handler)
let deps = self.dependencies.clone();
let handler = Arc::clone(&self.handler);
let default_handler = Arc::clone(&self.default_handler);
let error_handler = Arc::clone(&self.error_handler);
spawn_default_worker(
deps,
handler,
default_handler,
error_handler,
self.worker_queue_size,
)
}),
};
@ -316,18 +399,17 @@ where
}
}
const WORKER_QUEUE_SIZE: usize = 64;
fn spawn_worker<Err>(
deps: DependencyMap,
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
queue_size: usize,
) -> Worker
where
Err: Send + Sync + 'static,
{
let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);
let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
let deps = Arc::new(deps);
@ -348,11 +430,12 @@ fn spawn_default_worker<Err>(
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
queue_size: usize,
) -> Worker
where
Err: Send + Sync + 'static,
{
let (tx, rx) = tokio::sync::mpsc::channel(WORKER_QUEUE_SIZE);
let (tx, rx) = tokio::sync::mpsc::channel(queue_size);
let deps = Arc::new(deps);
@ -403,7 +486,7 @@ mod tests {
tokio::spawn(async {
// Just check that this code compiles.
if false {
Dispatcher::<_, Infallible>::builder(Bot::new(""), dptree::entry())
Dispatcher::<_, Infallible, _>::builder(Bot::new(""), dptree::entry())
.build()
.dispatch()
.await;

View file

@ -0,0 +1,9 @@
use teloxide_core::types::{ChatId, Update};
/// Default distribution key for dispatching.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct DefaultKey(ChatId);
pub(crate) fn default_distribution_function(update: &Update) -> Option<DefaultKey> {
update.chat().map(|c| c.id).map(DefaultKey)
}

View file

@ -100,6 +100,7 @@ pub mod repls;
pub mod dialogue;
mod dispatcher;
mod distribution;
mod filter_ext;
mod handler_ext;
mod handler_factory;
@ -108,6 +109,7 @@ pub mod update_listeners;
pub use crate::utils::shutdown_token::{IdleShutdownError, ShutdownToken};
pub use dispatcher::{Dispatcher, DispatcherBuilder, UpdateHandler};
pub use distribution::DefaultKey;
pub use filter_ext::{MessageFilterExt, UpdateFilterExt};
pub use handler_ext::HandlerExt;
#[allow(deprecated)]