Wait for workers by .awaiting join handles

This commit is contained in:
Hirrolot 2022-04-13 14:05:46 +06:00
parent 6a83fa5604
commit 579d5a7b7c

View file

@ -102,9 +102,9 @@ pub struct Dispatcher<R, Err> {
default_handler: DefaultHandler,
// Tokio TX channel parts associated with chat IDs that consume updates sequentially.
workers: HashMap<i64, WorkerTx>,
workers: HashMap<i64, Worker>,
// The default TX part that consume updates concurrently.
default_worker: Option<WorkerTx>,
default_worker: Option<Worker>,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
// TODO: respect allowed_udpates
@ -113,7 +113,10 @@ pub struct Dispatcher<R, Err> {
state: ShutdownToken,
}
type WorkerTx = tokio::sync::mpsc::Sender<Update>;
struct Worker {
tx: tokio::sync::mpsc::Sender<Update>,
handle: tokio::task::JoinHandle<()>,
}
// TODO: it is allowed to return message as response on telegram request in
// webhooks, so we can allow this too. See more there: https://core.telegram.org/bots/api#making-requests-when-getting-updates
@ -225,8 +228,15 @@ where
}
}
self.workers.drain();
self.default_worker.take();
for (_chat_id, worker) in self.workers.drain() {
drop(worker.tx);
worker.handle.await.expect("Unable to wait for a worker");
}
if let Some(worker) = self.default_worker.take() {
drop(worker.tx);
worker.handle.await.expect("Unable to wait for a default handler");
}
self.state.done();
}
@ -254,7 +264,7 @@ where
let default_handler = Arc::clone(&self.default_handler);
let error_handler = Arc::clone(&self.error_handler);
let tx = match upd.chat() {
let worker = match upd.chat() {
Some(chat) => self.workers.entry(chat.id).or_insert_with(|| {
spawn_worker(deps, handler, default_handler, error_handler)
}),
@ -263,7 +273,7 @@ where
}),
};
tx.send(upd).await.expect("TX is dead");
worker.tx.send(upd).await.expect("TX is dead");
}
Err(err) => err_handler.clone().handle_error(err).await,
}
@ -309,7 +319,7 @@ fn spawn_worker<Err>(
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
) -> WorkerTx
) -> Worker
where
Err: Send + Sync + 'static,
{
@ -317,7 +327,7 @@ where
let deps = Arc::new(deps);
tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
let handle = tokio::spawn(ReceiverStream::new(rx).for_each(move |update| {
let deps = Arc::clone(&deps);
let handler = Arc::clone(&handler);
let default_handler = Arc::clone(&default_handler);
@ -326,7 +336,7 @@ where
handle_update(update, deps, handler, default_handler, error_handler)
}));
tx
Worker { tx, handle }
}
fn spawn_default_worker<Err>(
@ -334,7 +344,7 @@ fn spawn_default_worker<Err>(
handler: Arc<UpdateHandler<Err>>,
default_handler: DefaultHandler,
error_handler: Arc<dyn ErrorHandler<Err> + Send + Sync>,
) -> WorkerTx
) -> Worker
where
Err: Send + Sync + 'static,
{
@ -342,7 +352,7 @@ where
let deps = Arc::new(deps);
tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| {
let handle = tokio::spawn(ReceiverStream::new(rx).for_each_concurrent(None, move |update| {
let deps = Arc::clone(&deps);
let handler = Arc::clone(&handler);
let default_handler = Arc::clone(&default_handler);
@ -351,7 +361,7 @@ where
handle_update(update, deps, handler, default_handler, error_handler)
}));
tx
Worker { tx, handle }
}
async fn handle_update<Err>(