Polling: poll stop flag instead of checking, to make sure we wakeup immediately when stopping

This commit is contained in:
Maybe Waffle 2023-09-25 21:10:50 +04:00
parent 2b7eea2679
commit 6cc6c04192

View file

@ -1,6 +1,7 @@
use std::{ use std::{
convert::TryInto, convert::TryInto,
future::Future, future::Future,
mem,
pin::Pin, pin::Pin,
task::{ task::{
self, self,
@ -97,8 +98,16 @@ where
pub fn build(self) -> Polling<R> { pub fn build(self) -> Polling<R> {
let Self { bot, timeout, limit, allowed_updates, drop_pending_updates } = self; let Self { bot, timeout, limit, allowed_updates, drop_pending_updates } = self;
let (token, flag) = mk_stop_token(); let (token, flag) = mk_stop_token();
let polling = let polling = Polling {
Polling { bot, timeout, limit, allowed_updates, drop_pending_updates, flag, token }; bot,
timeout,
limit,
allowed_updates,
drop_pending_updates,
flag: Some(flag),
token,
stop_token_cloned: false,
};
assert_update_listener(polling) assert_update_listener(polling)
} }
@ -240,17 +249,21 @@ pub struct Polling<B: Requester> {
limit: Option<u8>, limit: Option<u8>,
allowed_updates: Option<Vec<AllowedUpdate>>, allowed_updates: Option<Vec<AllowedUpdate>>,
drop_pending_updates: bool, drop_pending_updates: bool,
flag: StopFlag, flag: Option<StopFlag>,
token: StopToken, token: StopToken,
stop_token_cloned: bool,
} }
impl<R> Polling<R> impl<R> Polling<R>
where where
R: Requester + Send + 'static, R: Requester,
<R as Requester>::GetUpdates: Send,
{ {
/// Returns a builder for polling update listener. /// Returns a builder for polling update listener.
pub fn builder(bot: R) -> PollingBuilder<R> { pub fn builder(bot: R) -> PollingBuilder<R>
where
R: Send + 'static,
<R as Requester>::GetUpdates: Send,
{
PollingBuilder { PollingBuilder {
bot, bot,
timeout: None, timeout: None,
@ -259,6 +272,19 @@ where
drop_pending_updates: false, drop_pending_updates: false,
} }
} }
/// Returns true if re-initialization happened *and*
/// the previous token was cloned.
fn reinit_stop_flag_if_needed(&mut self) -> bool {
if self.flag.is_some() {
return false;
}
let (token, flag) = mk_stop_token();
self.token = token;
self.flag = Some(flag);
mem::replace(&mut self.stop_token_cloned, false)
}
} }
#[pin_project::pin_project] #[pin_project::pin_project]
@ -287,12 +313,18 @@ pub struct PollingStream<'a, B: Requester> {
/// In-flight `get_updates()` call. /// In-flight `get_updates()` call.
#[pin] #[pin]
in_flight: Option<<B::GetUpdates as Request>::Send>, in_flight: Option<<B::GetUpdates as Request>::Send>,
/// The flag that notifies polling to stop polling.
#[pin]
flag: StopFlag,
} }
impl<B: Requester + Send + 'static> UpdateListener for Polling<B> { impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
type Err = B::Err; type Err = B::Err;
fn stop_token(&mut self) -> StopToken { fn stop_token(&mut self) -> StopToken {
self.reinit_stop_flag_if_needed();
self.stop_token_cloned = true;
self.token.clone() self.token.clone()
} }
@ -311,6 +343,21 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")); let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
let allowed_updates = self.allowed_updates.clone(); let allowed_updates = self.allowed_updates.clone();
let drop_pending_updates = self.drop_pending_updates; let drop_pending_updates = self.drop_pending_updates;
let token_used_and_updated = self.reinit_stop_flag_if_needed();
// FIXME: document that `as_stream` is a destructive operation, actually,
// and you need to call `stop_token` *again* after it
if token_used_and_updated {
panic!(
"detected calling `as_stream` a second time after calling `stop_token`. \
`as_stream` updates the stop token, thus you need to call it again after calling \
`as_stream`"
)
}
// Unwrap: just called reinit
let flag = self.flag.take().unwrap();
PollingStream { PollingStream {
polling: self, polling: self,
drop_pending_updates, drop_pending_updates,
@ -321,6 +368,7 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
stopping: false, stopping: false,
buffer: Vec::new().into_iter(), buffer: Vec::new().into_iter(),
in_flight: None, in_flight: None,
flag,
} }
} }
} }
@ -343,7 +391,10 @@ impl<B: Requester> Stream for PollingStream<'_, B> {
// Check if we should stop and if so — drop in flight request, // Check if we should stop and if so — drop in flight request,
// we don't care about updates that happened *after* we started stopping // we don't care about updates that happened *after* we started stopping
if !*this.stopping && this.polling.flag.is_stopped() { //
// N.B.: it's important to use `poll` and not `is_stopped` here,
// so that *this stream* is polled when the flag is set to stop
if !*this.stopping && matches!(this.flag.poll(cx), Poll::Ready(())) {
*this.stopping = true; *this.stopping = true;
log::trace!("dropping in-flight request"); log::trace!("dropping in-flight request");