Merge pull request #505 from teloxide/refactor_shutdown_token

Refactor `ShutdownToken`
This commit is contained in:
Hirrolot 2022-02-05 13:33:07 +06:00 committed by GitHub
commit 25f863402d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 204 additions and 214 deletions

View file

@ -1,11 +1,4 @@
use std::{
fmt::{self, Debug},
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
time::Duration,
};
use std::{fmt::Debug, sync::Arc};
use crate::{
dispatching::{
@ -14,9 +7,10 @@ use crate::{
DispatcherHandler, UpdateWithCx,
},
error_handlers::{ErrorHandler, LoggingErrorHandler},
utils::shutdown_token::shutdown_check_timeout_for,
};
use futures::{stream::FuturesUnordered, Future, StreamExt};
use futures::{stream::FuturesUnordered, StreamExt};
use teloxide_core::{
requests::Requester,
types::{
@ -25,11 +19,9 @@ use teloxide_core::{
UpdateKind,
},
};
use tokio::{
sync::{mpsc, Notify},
task::JoinHandle,
time::timeout,
};
use tokio::{sync::mpsc, task::JoinHandle, time::timeout};
use crate::utils::shutdown_token::ShutdownToken;
type Tx<Upd, R> = Option<mpsc::UnboundedSender<UpdateWithCx<Upd, R>>>;
@ -58,8 +50,7 @@ pub struct Dispatcher<R> {
running_handlers: FuturesUnordered<JoinHandle<()>>,
state: Arc<DispatcherState>,
shutdown_notify_back: Arc<Notify>,
state: ShutdownToken,
}
impl<R> Dispatcher<R>
@ -86,8 +77,7 @@ where
chat_members_queue: None,
chat_join_requests_queue: None,
running_handlers: FuturesUnordered::new(),
state: <_>::default(),
shutdown_notify_back: <_>::default(),
state: ShutdownToken::new(),
}
}
@ -113,20 +103,18 @@ where
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "ctrlc_handler")))]
#[must_use]
pub fn setup_ctrlc_handler(self) -> Self {
let state = Arc::clone(&self.state);
let token = self.state.clone();
tokio::spawn(async move {
loop {
tokio::signal::ctrl_c().await.expect("Failed to listen for ^C");
match shutdown_inner(&state) {
Ok(()) => log::info!("^C received, trying to shutdown the dispatcher..."),
Err(Ok(AlreadyShuttingDown)) => {
log::info!(
"^C received, the dispatcher is already shutting down, ignoring the \
signal"
)
match token.shutdown() {
Ok(f) => {
log::info!("^C received, trying to shutdown the dispatcher...");
f.await;
log::info!("dispatcher is shutdown...");
}
Err(Err(IdleShutdownError)) => {
Err(_) => {
log::info!("^C received, the dispatcher isn't running, ignoring the signal")
}
}
@ -297,19 +285,12 @@ where
ListenerE: Debug,
R: Requester + Clone,
{
use ShutdownState::*;
self.hint_allowed_updates(&mut update_listener);
let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener);
let mut stop_token = Some(update_listener.stop_token());
if let Err(actual) = self.state.compare_exchange(Idle, Running) {
unreachable!(
"Dispatching is already running: expected `{:?}` state, found `{:?}`",
Idle, actual
);
}
self.state.start_dispatching();
{
let stream = update_listener.as_stream();
@ -325,7 +306,7 @@ where
}
}
if let ShuttingDown = self.state.load() {
if self.state.is_shutting_down() {
if let Some(token) = stop_token.take() {
log::debug!("Start shutting down dispatching...");
token.stop();
@ -335,27 +316,13 @@ where
}
self.wait_for_handlers().await;
if let ShuttingDown = self.state.load() {
// Stopped because of a `shutdown` call.
// Notify `shutdown`s that we finished
self.shutdown_notify_back.notify_waiters();
log::info!("Dispatching has been shut down.");
} else {
log::info!("Dispatching has been stopped (listener returned `None`).");
}
self.state.store(Idle);
self.state.done();
}
/// Returns a shutdown token, which can later be used to shutdown
/// dispatching.
pub fn shutdown_token(&self) -> ShutdownToken {
ShutdownToken {
dispatcher_state: Arc::clone(&self.state),
shutdown_notify_back: Arc::clone(&self.shutdown_notify_back),
}
self.state.clone()
}
async fn process_update<ListenerE, Eh>(
@ -547,123 +514,6 @@ where
}
}
/// This error is returned from [`ShutdownToken::shutdown`] when trying to
/// shutdown an idle [`Dispatcher`].
#[derive(Debug)]
pub struct IdleShutdownError;
impl fmt::Display for IdleShutdownError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Dispatcher was idle and as such couldn't be shut down")
}
}
impl std::error::Error for IdleShutdownError {}
/// A token which used to shutdown [`Dispatcher`].
#[derive(Clone)]
pub struct ShutdownToken {
pub(crate) dispatcher_state: Arc<DispatcherState>,
pub(crate) shutdown_notify_back: Arc<Notify>,
}
impl ShutdownToken {
/// Tries to shutdown dispatching.
///
/// Returns an error if the dispatcher is idle at the moment.
///
/// If you don't need to wait for shutdown, the returned future can be
/// ignored.
pub fn shutdown(&self) -> Result<impl Future<Output = ()> + '_, IdleShutdownError> {
match shutdown_inner(&self.dispatcher_state) {
Ok(()) | Err(Ok(AlreadyShuttingDown)) => Ok(async move {
log::info!("Trying to shutdown the dispatcher...");
self.shutdown_notify_back.notified().await
}),
Err(Err(err)) => Err(err),
}
}
}
pub(crate) struct DispatcherState {
inner: AtomicU8,
}
impl DispatcherState {
pub(crate) fn load(&self) -> ShutdownState {
ShutdownState::from_u8(self.inner.load(Ordering::SeqCst))
}
pub(crate) fn store(&self, new: ShutdownState) {
self.inner.store(new as _, Ordering::SeqCst)
}
pub(crate) fn compare_exchange(
&self,
current: ShutdownState,
new: ShutdownState,
) -> Result<ShutdownState, ShutdownState> {
self.inner
.compare_exchange(current as _, new as _, Ordering::SeqCst, Ordering::SeqCst)
.map(ShutdownState::from_u8)
.map_err(ShutdownState::from_u8)
}
}
impl Default for DispatcherState {
fn default() -> Self {
Self { inner: AtomicU8::new(ShutdownState::Idle as _) }
}
}
#[repr(u8)]
#[derive(Debug)]
pub(crate) enum ShutdownState {
Running,
ShuttingDown,
Idle,
}
impl ShutdownState {
fn from_u8(n: u8) -> Self {
const RUNNING: u8 = ShutdownState::Running as u8;
const SHUTTING_DOWN: u8 = ShutdownState::ShuttingDown as u8;
const IDLE: u8 = ShutdownState::Idle as u8;
match n {
RUNNING => ShutdownState::Running,
SHUTTING_DOWN => ShutdownState::ShuttingDown,
IDLE => ShutdownState::Idle,
_ => unreachable!(),
}
}
}
pub(crate) fn shutdown_check_timeout_for<E>(update_listener: &impl UpdateListener<E>) -> Duration {
const MIN_SHUTDOWN_CHECK_TIMEOUT: Duration = Duration::from_secs(1);
const DZERO: Duration = Duration::ZERO;
let shutdown_check_timeout = update_listener.timeout_hint().unwrap_or(DZERO);
shutdown_check_timeout.saturating_add(MIN_SHUTDOWN_CHECK_TIMEOUT)
}
pub(crate) struct AlreadyShuttingDown;
pub(crate) fn shutdown_inner(
state: &DispatcherState,
) -> Result<(), Result<AlreadyShuttingDown, IdleShutdownError>> {
use ShutdownState::*;
let res = state.compare_exchange(Running, ShuttingDown);
match res {
Ok(_) => Ok(()),
Err(ShuttingDown) => Err(Ok(AlreadyShuttingDown)),
Err(Idle) => Err(Err(IdleShutdownError)),
Err(Running) => unreachable!(),
}
}
fn send<'a, R, Upd>(requester: &'a R, tx: &'a Tx<R, Upd>, update: Upd, variant: &'static str)
where
Upd: Debug,

View file

@ -59,17 +59,13 @@ mod dispatcher_handler;
mod dispatcher_handler_rx_ext;
mod update_with_cx;
pub use dispatcher::{Dispatcher, IdleShutdownError, ShutdownToken};
pub use crate::utils::shutdown_token::{IdleShutdownError, ShutdownToken};
pub use dispatcher::Dispatcher;
pub use dispatcher_handler::DispatcherHandler;
pub use dispatcher_handler_rx_ext::DispatcherHandlerRxExt;
use tokio::sync::mpsc::UnboundedReceiver;
pub use update_with_cx::{UpdateWithCx, UpdateWithCxRequesterType};
#[cfg(feature = "dispatching2")]
pub(crate) use dispatcher::{
shutdown_check_timeout_for, shutdown_inner, DispatcherState, ShutdownState,
};
/// A type of a stream, consumed by [`Dispatcher`]'s handlers.
///
/// [`Dispatcher`]: crate::dispatching::Dispatcher

View file

@ -1,18 +1,18 @@
use crate::{
adaptors::CacheMe,
dispatching::{
shutdown_check_timeout_for, shutdown_inner, stop_token::StopToken, update_listeners,
update_listeners::UpdateListener, DispatcherState, ShutdownToken,
stop_token::StopToken, update_listeners, update_listeners::UpdateListener, ShutdownToken,
},
error_handlers::{ErrorHandler, LoggingErrorHandler},
requests::Requester,
types::{AllowedUpdate, Update},
utils::shutdown_token::shutdown_check_timeout_for,
};
use dptree::di::{DependencyMap, DependencySupplier};
use futures::{future::BoxFuture, StreamExt};
use std::{collections::HashSet, fmt::Debug, ops::ControlFlow, sync::Arc};
use teloxide_core::requests::{Request, RequesterExt};
use tokio::{sync::Notify, time::timeout};
use tokio::time::timeout;
use std::future::Future;
@ -77,8 +77,7 @@ where
default_handler: self.default_handler,
error_handler: self.error_handler,
allowed_updates: Default::default(),
state: Arc::new(Default::default()),
shutdown_notify_back: Arc::new(Default::default()),
state: ShutdownToken::new(),
}
}
}
@ -95,8 +94,7 @@ pub struct Dispatcher<R, Err> {
// TODO: respect allowed_udpates
allowed_updates: HashSet<AllowedUpdate>,
state: Arc<DispatcherState>,
shutdown_notify_back: Arc<Notify>,
state: ShutdownToken,
}
// TODO: it is allowed to return message as response on telegram request in
@ -174,19 +172,12 @@ where
Eh: ErrorHandler<ListenerE> + 'a,
ListenerE: Debug,
{
use crate::dispatching::ShutdownState::*;
update_listener.hint_allowed_updates(&mut self.allowed_updates.clone().into_iter());
let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener);
let mut stop_token = Some(update_listener.stop_token());
if let Err(actual) = self.state.compare_exchange(Idle, Running) {
unreachable!(
"Dispatching is already running: expected `{:?}` state, found `{:?}`",
Idle, actual
);
}
self.state.start_dispatching();
{
let stream = update_listener.as_stream();
@ -202,7 +193,7 @@ where
}
}
if let ShuttingDown = self.state.load() {
if self.state.is_shutting_down() {
if let Some(token) = stop_token.take() {
log::debug!("Start shutting down dispatching...");
token.stop();
@ -212,17 +203,9 @@ where
}
}
if let ShuttingDown = self.state.load() {
// Stopped because of a `shutdown` call.
// TODO: wait for executing handlers?
// Notify `shutdown`s that we finished
self.shutdown_notify_back.notify_waiters();
log::info!("Dispatching has been shut down.");
} else {
log::info!("Dispatching has been stopped (listener returned `None`).");
}
self.state.store(Idle);
self.state.done();
}
async fn process_update<LErr, LErrHandler>(
@ -262,20 +245,18 @@ where
#[cfg(feature = "ctrlc_handler")]
#[cfg_attr(docsrs, doc(cfg(feature = "ctrlc_handler")))]
pub fn setup_ctrlc_handler(&mut self) -> &mut Self {
let state = Arc::clone(&self.state);
let token = self.state.clone();
tokio::spawn(async move {
loop {
tokio::signal::ctrl_c().await.expect("Failed to listen for ^C");
match shutdown_inner(&state) {
Ok(()) => log::info!("^C received, trying to shutdown the dispatcher..."),
Err(Ok(_)) => {
log::info!(
"^C received, the dispatcher is already shutting down, ignoring the \
signal"
)
match token.shutdown() {
Ok(f) => {
log::info!("^C received, trying to shutdown the dispatcher...");
f.await;
log::info!("dispatcher is shutdown...");
}
Err(Err(_)) => {
Err(_) => {
log::info!("^C received, the dispatcher isn't running, ignoring the signal")
}
}
@ -288,9 +269,6 @@ where
/// Returns a shutdown token, which can later be used to shutdown
/// dispatching.
pub fn shutdown_token(&self) -> ShutdownToken {
ShutdownToken {
dispatcher_state: Arc::clone(&self.state),
shutdown_notify_back: Arc::clone(&self.shutdown_notify_back),
}
self.state.clone()
}
}

View file

@ -3,6 +3,7 @@
pub mod command;
pub mod html;
pub mod markdown;
pub(crate) mod shutdown_token;
mod up_state;
pub use teloxide_core::net::client_from_env;

165
src/utils/shutdown_token.rs Normal file
View file

@ -0,0 +1,165 @@
use std::{
fmt,
future::Future,
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
time::Duration,
};
use tokio::sync::Notify;
use crate::dispatching::update_listeners::UpdateListener;
/// A token which used to shutdown [`Dispatcher`].
#[derive(Clone)]
pub struct ShutdownToken {
dispatcher_state: Arc<DispatcherState>,
shutdown_notify_back: Arc<Notify>,
}
/// This error is returned from [`ShutdownToken::shutdown`] when trying to
/// shutdown an idle [`Dispatcher`].
#[derive(Debug)]
pub struct IdleShutdownError;
impl ShutdownToken {
/// Tries to shutdown dispatching.
///
/// Returns an error if the dispatcher is idle at the moment.
///
/// If you don't need to wait for shutdown, the returned future can be
/// ignored.
pub fn shutdown(&self) -> Result<impl Future<Output = ()> + '_, IdleShutdownError> {
match shutdown_inner(&self.dispatcher_state) {
Ok(()) | Err(Ok(AlreadyShuttingDown)) => Ok(async move {
log::info!("Trying to shutdown the dispatcher...");
self.shutdown_notify_back.notified().await
}),
Err(Err(err)) => Err(err),
}
}
pub(crate) fn new() -> Self {
Self {
dispatcher_state: Arc::new(DispatcherState {
inner: AtomicU8::new(ShutdownState::Idle as _),
}),
shutdown_notify_back: <_>::default(),
}
}
pub(crate) fn start_dispatching(&self) {
if let Err(actual) =
self.dispatcher_state.compare_exchange(ShutdownState::Idle, ShutdownState::Running)
{
panic!(
"Dispatching is already running: expected `{:?}` state, found `{:?}`",
ShutdownState::Idle,
actual
);
}
}
pub(crate) fn is_shutting_down(&self) -> bool {
matches!(self.dispatcher_state.load(), ShutdownState::ShuttingDown)
}
pub(crate) fn done(&self) {
if self.is_shutting_down() {
// Stopped because of a `shutdown` call.
// Notify `shutdown`s that we finished
self.shutdown_notify_back.notify_waiters();
log::info!("Dispatching has been shut down.");
} else {
log::info!("Dispatching has been stopped (listener returned `None`).");
}
self.dispatcher_state.store(ShutdownState::Idle);
}
}
impl fmt::Display for IdleShutdownError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Dispatcher was idle and as such couldn't be shut down")
}
}
impl std::error::Error for IdleShutdownError {}
pub(crate) fn shutdown_check_timeout_for<E>(update_listener: &impl UpdateListener<E>) -> Duration {
const MIN_SHUTDOWN_CHECK_TIMEOUT: Duration = Duration::from_secs(1);
const DZERO: Duration = Duration::ZERO;
let shutdown_check_timeout = update_listener.timeout_hint().unwrap_or(DZERO);
shutdown_check_timeout.saturating_add(MIN_SHUTDOWN_CHECK_TIMEOUT)
}
struct DispatcherState {
inner: AtomicU8,
}
impl DispatcherState {
// Ordering::Relaxed: only one atomic variable, nothing to synchronize.
fn load(&self) -> ShutdownState {
ShutdownState::from_u8(self.inner.load(Ordering::Relaxed))
}
fn store(&self, new: ShutdownState) {
self.inner.store(new as _, Ordering::Relaxed)
}
fn compare_exchange(
&self,
current: ShutdownState,
new: ShutdownState,
) -> Result<ShutdownState, ShutdownState> {
self.inner
.compare_exchange(current as _, new as _, Ordering::Relaxed, Ordering::Relaxed)
.map(ShutdownState::from_u8)
.map_err(ShutdownState::from_u8)
}
}
#[repr(u8)]
#[derive(Debug)]
enum ShutdownState {
Running,
ShuttingDown,
Idle,
}
impl ShutdownState {
fn from_u8(n: u8) -> Self {
const RUNNING: u8 = ShutdownState::Running as u8;
const SHUTTING_DOWN: u8 = ShutdownState::ShuttingDown as u8;
const IDLE: u8 = ShutdownState::Idle as u8;
match n {
RUNNING => ShutdownState::Running,
SHUTTING_DOWN => ShutdownState::ShuttingDown,
IDLE => ShutdownState::Idle,
_ => unreachable!(),
}
}
}
struct AlreadyShuttingDown;
fn shutdown_inner(
state: &DispatcherState,
) -> Result<(), Result<AlreadyShuttingDown, IdleShutdownError>> {
use ShutdownState::*;
let res = state.compare_exchange(Running, ShuttingDown);
match res {
Ok(_) => Ok(()),
Err(ShuttingDown) => Err(Ok(AlreadyShuttingDown)),
Err(Idle) => Err(Err(IdleShutdownError)),
Err(Running) => unreachable!(),
}
}