Auto-magically detect how much workers need to be kept alive

This commit is contained in:
Maybe Waffle 2022-06-26 22:53:41 +04:00
parent 9cb7ca9bd3
commit a820dedd50

View file

@ -11,20 +11,20 @@ use crate::{
use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use std::{
collections::HashMap,
fmt::Debug,
hash::Hash,
ops::{ControlFlow, Deref},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
use std::future::Future;
use std::{
collections::HashMap,
fmt::Debug,
future::Future,
hash::Hash,
ops::{ControlFlow, Deref},
sync::{
atomic::{AtomicBool, AtomicU32, Ordering},
Arc,
},
};
/// The builder for [`Dispatcher`].
pub struct DispatcherBuilder<R, Err, Key> {
@ -35,7 +35,6 @@ pub struct DispatcherBuilder<R, Err, Key> {
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
distribution_f: fn(&Update) -> Option<Key>,
worker_queue_size: usize,
gc_worker_count_trigger: usize,
}
impl<R, Err, Key> DispatcherBuilder<R, Err, Key>
@ -87,17 +86,6 @@ where
Self { worker_queue_size: size, ..self }
}
/// Maximum number of inactive workers.
///
/// When number of workers exceeds this limit dispatcher will try to remove
/// inactive workers.
///
/// By default it's 32.
#[must_use]
pub fn gc_worker_count_trigger(self, count: usize) -> Self {
Self { gc_worker_count_trigger: count, ..self }
}
/// Specifies the distribution function that decides how updates are grouped
/// before execution.
pub fn distribution_function<K>(
@ -115,7 +103,6 @@ where
error_handler,
distribution_f: _,
worker_queue_size,
gc_worker_count_trigger: worker_count_gc,
} = self;
DispatcherBuilder {
@ -126,7 +113,6 @@ where
error_handler,
distribution_f: f,
worker_queue_size,
gc_worker_count_trigger: worker_count_gc,
}
}
@ -141,7 +127,6 @@ where
error_handler,
distribution_f,
worker_queue_size,
gc_worker_count_trigger,
} = self;
Dispatcher {
@ -155,7 +140,8 @@ where
worker_queue_size,
workers: HashMap::new(),
default_worker: None,
gc_worker_count_trigger,
current_number_of_active_workers: Default::default(),
max_number_of_active_workers: Default::default(),
}
}
}
@ -177,7 +163,8 @@ pub struct Dispatcher<R, Err, Key> {
distribution_f: fn(&Update) -> Option<Key>,
worker_queue_size: usize,
gc_worker_count_trigger: usize,
current_number_of_active_workers: Arc<AtomicU32>,
max_number_of_active_workers: Arc<AtomicU32>,
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
workers: HashMap<Key, Worker>,
// The default TX part that consume updates concurrently.
@ -215,7 +202,6 @@ where
Err: Debug,
{
const DEFAULT_WORKER_QUEUE_SIZE: usize = 64;
const DEFAULT_GC_WORKER_COUNT_TRIGGER: usize = 32;
DispatcherBuilder {
bot,
@ -228,7 +214,6 @@ where
error_handler: LoggingErrorHandler::new(),
worker_queue_size: DEFAULT_WORKER_QUEUE_SIZE,
distribution_f: default_distribution_function,
gc_worker_count_trigger: DEFAULT_GC_WORKER_COUNT_TRIGGER,
}
}
}
@ -367,6 +352,8 @@ where
handler,
default_handler,
error_handler,
Arc::clone(&self.current_number_of_active_workers),
Arc::clone(&self.max_number_of_active_workers),
self.worker_queue_size,
)
}),
@ -393,7 +380,10 @@ where
}
async fn gc_workers_if_needed(&mut self) {
if self.workers.len() <= self.gc_worker_count_trigger {
let workers = self.workers.len();
let max = self.max_number_of_active_workers.load(Ordering::Relaxed) as usize;
if workers <= max {
return;
}
@ -469,6 +459,8 @@ fn spawn_worker<Err>(
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
current_number_of_active_workers: Arc<AtomicU32>,
max_number_of_active_workers: Arc<AtomicU32>,
queue_size: usize,
) -> Worker
where
@ -483,6 +475,10 @@ where
let handle = tokio::spawn(async move {
while let Some(update) = rx.recv().await {
is_waiting_local.store(false, Ordering::Relaxed);
{
let current = current_number_of_active_workers.fetch_add(1, Ordering::Relaxed) + 1;
max_number_of_active_workers.fetch_max(current, Ordering::Relaxed);
}
let deps = Arc::clone(&deps);
let handler = Arc::clone(&handler);
@ -491,6 +487,7 @@ where
handle_update(update, deps, handler, default_handler, error_handler).await;
current_number_of_active_workers.fetch_sub(1, Ordering::Relaxed);
is_waiting_local.store(true, Ordering::Relaxed);
}
});