From 07e08bef6ccf2293f1950042f5be8620fd8d2cd8 Mon Sep 17 00:00:00 2001 From: Maybe Waffle Date: Mon, 25 Sep 2023 20:02:57 +0400 Subject: [PATCH] Don't use timeout to check `ShutdownToken` --- crates/teloxide/src/dispatching/dispatcher.rs | 22 ++++++--------- crates/teloxide/src/utils/shutdown_token.rs | 27 ++++++++++--------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/crates/teloxide/src/dispatching/dispatcher.rs b/crates/teloxide/src/dispatching/dispatcher.rs index d7c14387..07e80cc9 100644 --- a/crates/teloxide/src/dispatching/dispatcher.rs +++ b/crates/teloxide/src/dispatching/dispatcher.rs @@ -7,12 +7,10 @@ use crate::{ requests::{Request, Requester}, types::{Update, UpdateKind}, update_listeners::{self, UpdateListener}, - utils::shutdown_token::shutdown_check_timeout_for, }; use dptree::di::{DependencyMap, DependencySupplier}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; -use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; use std::{ @@ -312,7 +310,6 @@ where log::debug!("hinting allowed updates: {:?}", allowed_updates); update_listener.hint_allowed_updates(&mut allowed_updates.into_iter()); - let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener); let mut stop_token = Some(update_listener.stop_token()); self.state.start_dispatching(); @@ -324,19 +321,16 @@ where loop { self.remove_inactive_workers_if_needed().await; - // False positive - #[allow(clippy::collapsible_match)] - if let Ok(upd) = timeout(shutdown_check_timeout, stream.next()).await { - match upd { + tokio::select! { + upd = stream.next() => match upd { None => break, Some(upd) => self.process_update(upd, &update_listener_error_handler).await, - } - } - - if self.state.is_shutting_down() { - if let Some(token) = stop_token.take() { - log::debug!("Start shutting down dispatching..."); - token.stop(); + }, + () = self.state.wait_for_changes() => if self.state.is_shutting_down() { + if let Some(token) = stop_token.take() { + log::debug!("Start shutting down dispatching..."); + token.stop(); + } } } } diff --git a/crates/teloxide/src/utils/shutdown_token.rs b/crates/teloxide/src/utils/shutdown_token.rs index cb4da020..14ca622f 100644 --- a/crates/teloxide/src/utils/shutdown_token.rs +++ b/crates/teloxide/src/utils/shutdown_token.rs @@ -5,18 +5,16 @@ use std::{ atomic::{AtomicU8, Ordering}, Arc, }, - time::Duration, }; use tokio::sync::Notify; -use crate::update_listeners::UpdateListener; - /// A token which used to shutdown [`Dispatcher`]. /// /// [`Dispatcher`]: crate::dispatching::Dispatcher #[derive(Clone)] pub struct ShutdownToken { + // FIXME: use a single arc dispatcher_state: Arc, shutdown_notify_back: Arc, } @@ -49,11 +47,16 @@ impl ShutdownToken { Self { dispatcher_state: Arc::new(DispatcherState { inner: AtomicU8::new(ShutdownState::Idle as _), + notify: <_>::default(), }), shutdown_notify_back: <_>::default(), } } + pub(crate) async fn wait_for_changes(&self) { + self.dispatcher_state.notify.notified().await; + } + pub(crate) fn start_dispatching(&self) { if let Err(actual) = self.dispatcher_state.compare_exchange(ShutdownState::Idle, ShutdownState::Running) @@ -93,27 +96,20 @@ impl fmt::Display for IdleShutdownError { impl std::error::Error for IdleShutdownError {} -pub(crate) fn shutdown_check_timeout_for(update_listener: &impl UpdateListener) -> Duration { - const MIN_SHUTDOWN_CHECK_TIMEOUT: Duration = Duration::from_secs(1); - const DZERO: Duration = Duration::ZERO; - - let shutdown_check_timeout = update_listener.timeout_hint().unwrap_or(DZERO); - shutdown_check_timeout.saturating_add(MIN_SHUTDOWN_CHECK_TIMEOUT) -} - struct DispatcherState { inner: AtomicU8, + notify: Notify, } impl DispatcherState { // Ordering::Relaxed: only one atomic variable, nothing to synchronize. - fn load(&self) -> ShutdownState { ShutdownState::from_u8(self.inner.load(Ordering::Relaxed)) } fn store(&self, new: ShutdownState) { - self.inner.store(new as _, Ordering::Relaxed) + self.inner.store(new as _, Ordering::Relaxed); + self.notify.notify_waiters(); } fn compare_exchange( @@ -125,6 +121,11 @@ impl DispatcherState { .compare_exchange(current as _, new as _, Ordering::Relaxed, Ordering::Relaxed) .map(ShutdownState::from_u8) .map_err(ShutdownState::from_u8) + // FIXME: `Result::inspect` when :( + .map(|st| { + self.notify.notify_waiters(); + st + }) } }