mirror of
https://github.com/teloxide/teloxide.git
synced 2025-01-08 19:33:53 +01:00
Merge pull request #938 from teloxide/⚠️polling-in-flight⚠️
Improve graceful shutdown
This commit is contained in:
commit
d21ca11a54
6 changed files with 115 additions and 84 deletions
|
@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- Add `MessageToCopyNotFound` error to `teloxide::errors::ApiError` ([PR 917](https://github.com/teloxide/teloxide/pull/917))
|
- Add `MessageToCopyNotFound` error to `teloxide::errors::ApiError` ([PR 917](https://github.com/teloxide/teloxide/pull/917))
|
||||||
### Fixed
|
### Fixed
|
||||||
- Use `UserId` instead of `i64` for `user_id` in `html::user_mention` and `markdown::user_mention` ([PR 896](https://github.com/teloxide/teloxide/pull/896))
|
- Use `UserId` instead of `i64` for `user_id` in `html::user_mention` and `markdown::user_mention` ([PR 896](https://github.com/teloxide/teloxide/pull/896))
|
||||||
|
- Greatly improved the speed of graceful shutdown (`^C`) ([PR 938](https://github.com/teloxide/teloxide/pull/938))
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
|
||||||
|
- `UpdateListener::timeout_hint` and related APIs ([PR 938](https://github.com/teloxide/teloxide/pull/938))
|
||||||
|
|
||||||
## 0.12.2 - 2023-02-15
|
## 0.12.2 - 2023-02-15
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,10 @@ use crate::{
|
||||||
requests::{Request, Requester},
|
requests::{Request, Requester},
|
||||||
types::{Update, UpdateKind},
|
types::{Update, UpdateKind},
|
||||||
update_listeners::{self, UpdateListener},
|
update_listeners::{self, UpdateListener},
|
||||||
utils::shutdown_token::shutdown_check_timeout_for,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use dptree::di::{DependencyMap, DependencySupplier};
|
use dptree::di::{DependencyMap, DependencySupplier};
|
||||||
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
|
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
|
||||||
use tokio::time::timeout;
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -312,7 +310,6 @@ where
|
||||||
log::debug!("hinting allowed updates: {:?}", allowed_updates);
|
log::debug!("hinting allowed updates: {:?}", allowed_updates);
|
||||||
update_listener.hint_allowed_updates(&mut allowed_updates.into_iter());
|
update_listener.hint_allowed_updates(&mut allowed_updates.into_iter());
|
||||||
|
|
||||||
let shutdown_check_timeout = shutdown_check_timeout_for(&update_listener);
|
|
||||||
let mut stop_token = Some(update_listener.stop_token());
|
let mut stop_token = Some(update_listener.stop_token());
|
||||||
|
|
||||||
self.state.start_dispatching();
|
self.state.start_dispatching();
|
||||||
|
@ -324,19 +321,16 @@ where
|
||||||
loop {
|
loop {
|
||||||
self.remove_inactive_workers_if_needed().await;
|
self.remove_inactive_workers_if_needed().await;
|
||||||
|
|
||||||
// False positive
|
tokio::select! {
|
||||||
#[allow(clippy::collapsible_match)]
|
upd = stream.next() => match upd {
|
||||||
if let Ok(upd) = timeout(shutdown_check_timeout, stream.next()).await {
|
|
||||||
match upd {
|
|
||||||
None => break,
|
None => break,
|
||||||
Some(upd) => self.process_update(upd, &update_listener_error_handler).await,
|
Some(upd) => self.process_update(upd, &update_listener_error_handler).await,
|
||||||
}
|
},
|
||||||
}
|
() = self.state.wait_for_changes() => if self.state.is_shutting_down() {
|
||||||
|
if let Some(token) = stop_token.take() {
|
||||||
if self.state.is_shutting_down() {
|
log::debug!("Start shutting down dispatching...");
|
||||||
if let Some(token) = stop_token.take() {
|
token.stop();
|
||||||
log::debug!("Start shutting down dispatching...");
|
}
|
||||||
token.stop();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,8 +32,6 @@ pub mod webhooks;
|
||||||
|
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
stop::StopToken,
|
stop::StopToken,
|
||||||
types::{AllowedUpdate, Update},
|
types::{AllowedUpdate, Update},
|
||||||
|
@ -94,19 +92,6 @@ pub trait UpdateListener:
|
||||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
||||||
let _ = hint;
|
let _ = hint;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The timeout duration hint.
|
|
||||||
///
|
|
||||||
/// This hints how often dispatcher should check for a shutdown. E.g., for
|
|
||||||
/// [`polling()`] this returns the [`timeout`].
|
|
||||||
///
|
|
||||||
/// [`timeout`]: crate::payloads::GetUpdates::timeout
|
|
||||||
///
|
|
||||||
/// If you are implementing this trait and not sure what to return from this
|
|
||||||
/// function, just leave it with the default implementation.
|
|
||||||
fn timeout_hint(&self) -> Option<Duration> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// [`UpdateListener`]'s supertrait/extension.
|
/// [`UpdateListener`]'s supertrait/extension.
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use std::{
|
use std::{
|
||||||
convert::TryInto,
|
convert::TryInto,
|
||||||
future::Future,
|
future::Future,
|
||||||
|
mem,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{
|
task::{
|
||||||
self,
|
self,
|
||||||
|
@ -97,8 +98,16 @@ where
|
||||||
pub fn build(self) -> Polling<R> {
|
pub fn build(self) -> Polling<R> {
|
||||||
let Self { bot, timeout, limit, allowed_updates, drop_pending_updates } = self;
|
let Self { bot, timeout, limit, allowed_updates, drop_pending_updates } = self;
|
||||||
let (token, flag) = mk_stop_token();
|
let (token, flag) = mk_stop_token();
|
||||||
let polling =
|
let polling = Polling {
|
||||||
Polling { bot, timeout, limit, allowed_updates, drop_pending_updates, flag, token };
|
bot,
|
||||||
|
timeout,
|
||||||
|
limit,
|
||||||
|
allowed_updates,
|
||||||
|
drop_pending_updates,
|
||||||
|
flag: Some(flag),
|
||||||
|
token,
|
||||||
|
stop_token_cloned: false,
|
||||||
|
};
|
||||||
|
|
||||||
assert_update_listener(polling)
|
assert_update_listener(polling)
|
||||||
}
|
}
|
||||||
|
@ -240,17 +249,21 @@ pub struct Polling<B: Requester> {
|
||||||
limit: Option<u8>,
|
limit: Option<u8>,
|
||||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||||
drop_pending_updates: bool,
|
drop_pending_updates: bool,
|
||||||
flag: StopFlag,
|
flag: Option<StopFlag>,
|
||||||
token: StopToken,
|
token: StopToken,
|
||||||
|
stop_token_cloned: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Polling<R>
|
impl<R> Polling<R>
|
||||||
where
|
where
|
||||||
R: Requester + Send + 'static,
|
R: Requester,
|
||||||
<R as Requester>::GetUpdates: Send,
|
|
||||||
{
|
{
|
||||||
/// Returns a builder for polling update listener.
|
/// Returns a builder for polling update listener.
|
||||||
pub fn builder(bot: R) -> PollingBuilder<R> {
|
pub fn builder(bot: R) -> PollingBuilder<R>
|
||||||
|
where
|
||||||
|
R: Send + 'static,
|
||||||
|
<R as Requester>::GetUpdates: Send,
|
||||||
|
{
|
||||||
PollingBuilder {
|
PollingBuilder {
|
||||||
bot,
|
bot,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
|
@ -259,6 +272,19 @@ where
|
||||||
drop_pending_updates: false,
|
drop_pending_updates: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if re-initialization happened *and*
|
||||||
|
/// the previous token was cloned.
|
||||||
|
fn reinit_stop_flag_if_needed(&mut self) -> bool {
|
||||||
|
if self.flag.is_some() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (token, flag) = mk_stop_token();
|
||||||
|
self.token = token;
|
||||||
|
self.flag = Some(flag);
|
||||||
|
mem::replace(&mut self.stop_token_cloned, false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pin_project::pin_project]
|
#[pin_project::pin_project]
|
||||||
|
@ -287,12 +313,18 @@ pub struct PollingStream<'a, B: Requester> {
|
||||||
/// In-flight `get_updates()` call.
|
/// In-flight `get_updates()` call.
|
||||||
#[pin]
|
#[pin]
|
||||||
in_flight: Option<<B::GetUpdates as Request>::Send>,
|
in_flight: Option<<B::GetUpdates as Request>::Send>,
|
||||||
|
|
||||||
|
/// The flag that notifies polling to stop polling.
|
||||||
|
#[pin]
|
||||||
|
flag: StopFlag,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
|
impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
|
||||||
type Err = B::Err;
|
type Err = B::Err;
|
||||||
|
|
||||||
fn stop_token(&mut self) -> StopToken {
|
fn stop_token(&mut self) -> StopToken {
|
||||||
|
self.reinit_stop_flag_if_needed();
|
||||||
|
self.stop_token_cloned = true;
|
||||||
self.token.clone()
|
self.token.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -301,10 +333,6 @@ impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
|
||||||
// before
|
// before
|
||||||
self.allowed_updates = Some(hint.collect());
|
self.allowed_updates = Some(hint.collect());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn timeout_hint(&self) -> Option<Duration> {
|
|
||||||
self.timeout
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
|
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
|
||||||
|
@ -315,6 +343,21 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
|
||||||
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
|
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
|
||||||
let allowed_updates = self.allowed_updates.clone();
|
let allowed_updates = self.allowed_updates.clone();
|
||||||
let drop_pending_updates = self.drop_pending_updates;
|
let drop_pending_updates = self.drop_pending_updates;
|
||||||
|
|
||||||
|
let token_used_and_updated = self.reinit_stop_flag_if_needed();
|
||||||
|
|
||||||
|
// FIXME: document that `as_stream` is a destructive operation, actually,
|
||||||
|
// and you need to call `stop_token` *again* after it
|
||||||
|
if token_used_and_updated {
|
||||||
|
panic!(
|
||||||
|
"detected calling `as_stream` a second time after calling `stop_token`. \
|
||||||
|
`as_stream` updates the stop token, thus you need to call it again after calling \
|
||||||
|
`as_stream`"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap: just called reinit
|
||||||
|
let flag = self.flag.take().unwrap();
|
||||||
PollingStream {
|
PollingStream {
|
||||||
polling: self,
|
polling: self,
|
||||||
drop_pending_updates,
|
drop_pending_updates,
|
||||||
|
@ -325,6 +368,7 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
|
||||||
stopping: false,
|
stopping: false,
|
||||||
buffer: Vec::new().into_iter(),
|
buffer: Vec::new().into_iter(),
|
||||||
in_flight: None,
|
in_flight: None,
|
||||||
|
flag,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -333,15 +377,33 @@ impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||||
type Item = Result<Update, B::Err>;
|
type Item = Result<Update, B::Err>;
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
log::trace!("polling polling stream");
|
||||||
let mut this = self.as_mut().project();
|
let mut this = self.as_mut().project();
|
||||||
|
|
||||||
if *this.force_stop {
|
if *this.force_stop {
|
||||||
return Ready(None);
|
return Ready(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there are any buffered updates, return one
|
||||||
|
if let Some(upd) = this.buffer.next() {
|
||||||
|
return Ready(Some(Ok(upd)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we should stop and if so — drop in flight request,
|
||||||
|
// we don't care about updates that happened *after* we started stopping
|
||||||
|
//
|
||||||
|
// N.B.: it's important to use `poll` and not `is_stopped` here,
|
||||||
|
// so that *this stream* is polled when the flag is set to stop
|
||||||
|
if !*this.stopping && matches!(this.flag.poll(cx), Poll::Ready(())) {
|
||||||
|
*this.stopping = true;
|
||||||
|
|
||||||
|
log::trace!("dropping in-flight request");
|
||||||
|
this.in_flight.set(None);
|
||||||
|
}
|
||||||
// Poll in-flight future until completion
|
// Poll in-flight future until completion
|
||||||
if let Some(in_flight) = this.in_flight.as_mut().as_pin_mut() {
|
else if let Some(in_flight) = this.in_flight.as_mut().as_pin_mut() {
|
||||||
let res = ready!(in_flight.poll(cx));
|
let res = ready!(in_flight.poll(cx));
|
||||||
|
log::trace!("in-flight request completed");
|
||||||
this.in_flight.set(None);
|
this.in_flight.set(None);
|
||||||
|
|
||||||
match res {
|
match res {
|
||||||
|
@ -366,12 +428,6 @@ impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are any buffered updates, return one
|
|
||||||
if let Some(upd) = this.buffer.next() {
|
|
||||||
return Ready(Some(Ok(upd)));
|
|
||||||
}
|
|
||||||
|
|
||||||
*this.stopping = this.polling.flag.is_stopped();
|
|
||||||
let (offset, limit, timeout) = match (this.stopping, this.drop_pending_updates) {
|
let (offset, limit, timeout) = match (this.stopping, this.drop_pending_updates) {
|
||||||
// Normal `get_updates()` call
|
// Normal `get_updates()` call
|
||||||
(false, false) => (*this.offset, this.polling.limit, *this.timeout),
|
(false, false) => (*this.offset, this.polling.limit, *this.timeout),
|
||||||
|
@ -380,7 +436,10 @@ impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||||
//
|
//
|
||||||
// When stopping we set `timeout = 0` and `limit = 1` so that `get_updates()`
|
// When stopping we set `timeout = 0` and `limit = 1` so that `get_updates()`
|
||||||
// set last seen update (offset) and return immediately
|
// set last seen update (offset) and return immediately
|
||||||
(true, _) => (*this.offset, Some(1), Some(0)),
|
(true, _) => {
|
||||||
|
log::trace!("graceful shutdown `get_updates` call");
|
||||||
|
(*this.offset, Some(1), Some(0))
|
||||||
|
}
|
||||||
// Drop pending updates
|
// Drop pending updates
|
||||||
(_, true) => (-1, Some(1), Some(0)),
|
(_, true) => (-1, Some(1), Some(0)),
|
||||||
};
|
};
|
||||||
|
@ -398,8 +457,10 @@ impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||||
.send();
|
.send();
|
||||||
this.in_flight.set(Some(req));
|
this.in_flight.set(Some(req));
|
||||||
|
|
||||||
// Recurse to poll `self.in_flight`
|
// Immediately wake up to poll `self.in_flight`
|
||||||
self.poll_next(cx)
|
// (without this this stream becomes a zombie)
|
||||||
|
cx.waker().wake_by_ref();
|
||||||
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -17,7 +15,7 @@ use crate::{
|
||||||
///
|
///
|
||||||
/// [`polling`]: crate::update_listeners::polling()
|
/// [`polling`]: crate::update_listeners::polling()
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub struct StatefulListener<St, Assf, Sf, Hauf, Thf> {
|
pub struct StatefulListener<St, Assf, Sf, Hauf> {
|
||||||
/// The state of the listener.
|
/// The state of the listener.
|
||||||
pub state: St,
|
pub state: St,
|
||||||
|
|
||||||
|
@ -36,38 +34,30 @@ pub struct StatefulListener<St, Assf, Sf, Hauf, Thf> {
|
||||||
/// Must implement `FnMut(&mut St, &mut dyn Iterator<Item =
|
/// Must implement `FnMut(&mut St, &mut dyn Iterator<Item =
|
||||||
/// AllowedUpdate>)`.
|
/// AllowedUpdate>)`.
|
||||||
pub hint_allowed_updates: Option<Hauf>,
|
pub hint_allowed_updates: Option<Hauf>,
|
||||||
|
|
||||||
/// The function used as [`UpdateListener::timeout_hint`].
|
|
||||||
///
|
|
||||||
/// Must implement `Fn(&St) -> Option<Duration>`.
|
|
||||||
pub timeout_hint: Option<Thf>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Haufn<State> = for<'a, 'b> fn(&'a mut State, &'b mut dyn Iterator<Item = AllowedUpdate>);
|
type Haufn<State> = for<'a, 'b> fn(&'a mut State, &'b mut dyn Iterator<Item = AllowedUpdate>);
|
||||||
type Thfn<State> = for<'a> fn(&'a State) -> Option<Duration>;
|
|
||||||
|
|
||||||
impl<St, Assf, Sf> StatefulListener<St, Assf, Sf, Haufn<St>, Thfn<St>> {
|
impl<St, Assf, Sf> StatefulListener<St, Assf, Sf, Haufn<St>> {
|
||||||
/// Creates a new stateful listener from its components.
|
/// Creates a new stateful listener from its components.
|
||||||
pub fn new(state: St, stream: Assf, stop_token: Sf) -> Self {
|
pub fn new(state: St, stream: Assf, stop_token: Sf) -> Self {
|
||||||
Self::new_with_hints(state, stream, stop_token, None, None)
|
Self::new_with_hints(state, stream, stop_token, None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<St, Assf, Sf, Hauf, Thf> StatefulListener<St, Assf, Sf, Hauf, Thf> {
|
impl<St, Assf, Sf, Hauf> StatefulListener<St, Assf, Sf, Hauf> {
|
||||||
/// Creates a new stateful listener from its components.
|
/// Creates a new stateful listener from its components.
|
||||||
pub fn new_with_hints(
|
pub fn new_with_hints(
|
||||||
state: St,
|
state: St,
|
||||||
stream: Assf,
|
stream: Assf,
|
||||||
stop_token: Sf,
|
stop_token: Sf,
|
||||||
hint_allowed_updates: Option<Hauf>,
|
hint_allowed_updates: Option<Hauf>,
|
||||||
timeout_hint: Option<Thf>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self { state, stream, stop_token, hint_allowed_updates, timeout_hint }
|
Self { state, stream, stop_token, hint_allowed_updates }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, St, Assf, Sf, Hauf, Thf, Strm, E> AsUpdateStream<'a>
|
impl<'a, St, Assf, Sf, Hauf, Strm, E> AsUpdateStream<'a> for StatefulListener<St, Assf, Hauf, Sf>
|
||||||
for StatefulListener<St, Assf, Hauf, Sf, Thf>
|
|
||||||
where
|
where
|
||||||
(St, Strm): 'a,
|
(St, Strm): 'a,
|
||||||
Strm: Send,
|
Strm: Send,
|
||||||
|
@ -82,12 +72,11 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<St, Assf, Sf, Hauf, Thf, E> UpdateListener for StatefulListener<St, Assf, Sf, Hauf, Thf>
|
impl<St, Assf, Sf, Hauf, E> UpdateListener for StatefulListener<St, Assf, Sf, Hauf>
|
||||||
where
|
where
|
||||||
Self: for<'a> AsUpdateStream<'a, StreamErr = E>,
|
Self: for<'a> AsUpdateStream<'a, StreamErr = E>,
|
||||||
Sf: FnMut(&mut St) -> StopToken,
|
Sf: FnMut(&mut St) -> StopToken,
|
||||||
Hauf: FnMut(&mut St, &mut dyn Iterator<Item = AllowedUpdate>),
|
Hauf: FnMut(&mut St, &mut dyn Iterator<Item = AllowedUpdate>),
|
||||||
Thf: Fn(&St) -> Option<Duration>,
|
|
||||||
{
|
{
|
||||||
type Err = E;
|
type Err = E;
|
||||||
|
|
||||||
|
@ -100,8 +89,4 @@ where
|
||||||
f(&mut self.state, hint);
|
f(&mut self.state, hint);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn timeout_hint(&self) -> Option<Duration> {
|
|
||||||
self.timeout_hint.as_ref().and_then(|f| f(&self.state))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,18 +5,16 @@ use std::{
|
||||||
atomic::{AtomicU8, Ordering},
|
atomic::{AtomicU8, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
},
|
},
|
||||||
time::Duration,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
|
|
||||||
use crate::update_listeners::UpdateListener;
|
|
||||||
|
|
||||||
/// A token which used to shutdown [`Dispatcher`].
|
/// A token which used to shutdown [`Dispatcher`].
|
||||||
///
|
///
|
||||||
/// [`Dispatcher`]: crate::dispatching::Dispatcher
|
/// [`Dispatcher`]: crate::dispatching::Dispatcher
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ShutdownToken {
|
pub struct ShutdownToken {
|
||||||
|
// FIXME: use a single arc
|
||||||
dispatcher_state: Arc<DispatcherState>,
|
dispatcher_state: Arc<DispatcherState>,
|
||||||
shutdown_notify_back: Arc<Notify>,
|
shutdown_notify_back: Arc<Notify>,
|
||||||
}
|
}
|
||||||
|
@ -49,11 +47,16 @@ impl ShutdownToken {
|
||||||
Self {
|
Self {
|
||||||
dispatcher_state: Arc::new(DispatcherState {
|
dispatcher_state: Arc::new(DispatcherState {
|
||||||
inner: AtomicU8::new(ShutdownState::Idle as _),
|
inner: AtomicU8::new(ShutdownState::Idle as _),
|
||||||
|
notify: <_>::default(),
|
||||||
}),
|
}),
|
||||||
shutdown_notify_back: <_>::default(),
|
shutdown_notify_back: <_>::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn wait_for_changes(&self) {
|
||||||
|
self.dispatcher_state.notify.notified().await;
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn start_dispatching(&self) {
|
pub(crate) fn start_dispatching(&self) {
|
||||||
if let Err(actual) =
|
if let Err(actual) =
|
||||||
self.dispatcher_state.compare_exchange(ShutdownState::Idle, ShutdownState::Running)
|
self.dispatcher_state.compare_exchange(ShutdownState::Idle, ShutdownState::Running)
|
||||||
|
@ -93,27 +96,20 @@ impl fmt::Display for IdleShutdownError {
|
||||||
|
|
||||||
impl std::error::Error for IdleShutdownError {}
|
impl std::error::Error for IdleShutdownError {}
|
||||||
|
|
||||||
pub(crate) fn shutdown_check_timeout_for(update_listener: &impl UpdateListener) -> Duration {
|
|
||||||
const MIN_SHUTDOWN_CHECK_TIMEOUT: Duration = Duration::from_secs(1);
|
|
||||||
const DZERO: Duration = Duration::ZERO;
|
|
||||||
|
|
||||||
let shutdown_check_timeout = update_listener.timeout_hint().unwrap_or(DZERO);
|
|
||||||
shutdown_check_timeout.saturating_add(MIN_SHUTDOWN_CHECK_TIMEOUT)
|
|
||||||
}
|
|
||||||
|
|
||||||
struct DispatcherState {
|
struct DispatcherState {
|
||||||
inner: AtomicU8,
|
inner: AtomicU8,
|
||||||
|
notify: Notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DispatcherState {
|
impl DispatcherState {
|
||||||
// Ordering::Relaxed: only one atomic variable, nothing to synchronize.
|
// Ordering::Relaxed: only one atomic variable, nothing to synchronize.
|
||||||
|
|
||||||
fn load(&self) -> ShutdownState {
|
fn load(&self) -> ShutdownState {
|
||||||
ShutdownState::from_u8(self.inner.load(Ordering::Relaxed))
|
ShutdownState::from_u8(self.inner.load(Ordering::Relaxed))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn store(&self, new: ShutdownState) {
|
fn store(&self, new: ShutdownState) {
|
||||||
self.inner.store(new as _, Ordering::Relaxed)
|
self.inner.store(new as _, Ordering::Relaxed);
|
||||||
|
self.notify.notify_waiters();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compare_exchange(
|
fn compare_exchange(
|
||||||
|
@ -125,6 +121,11 @@ impl DispatcherState {
|
||||||
.compare_exchange(current as _, new as _, Ordering::Relaxed, Ordering::Relaxed)
|
.compare_exchange(current as _, new as _, Ordering::Relaxed, Ordering::Relaxed)
|
||||||
.map(ShutdownState::from_u8)
|
.map(ShutdownState::from_u8)
|
||||||
.map_err(ShutdownState::from_u8)
|
.map_err(ShutdownState::from_u8)
|
||||||
|
// FIXME: `Result::inspect` when :(
|
||||||
|
.map(|st| {
|
||||||
|
self.notify.notify_waiters();
|
||||||
|
st
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue