diff --git a/crates/teloxide/src/update_listeners/polling.rs b/crates/teloxide/src/update_listeners/polling.rs index b713db37..62d63076 100644 --- a/crates/teloxide/src/update_listeners/polling.rs +++ b/crates/teloxide/src/update_listeners/polling.rs @@ -1,6 +1,7 @@ use std::{ convert::TryInto, future::Future, + mem, pin::Pin, task::{ self, @@ -97,8 +98,16 @@ where pub fn build(self) -> Polling { let Self { bot, timeout, limit, allowed_updates, drop_pending_updates } = self; let (token, flag) = mk_stop_token(); - let polling = - Polling { bot, timeout, limit, allowed_updates, drop_pending_updates, flag, token }; + let polling = Polling { + bot, + timeout, + limit, + allowed_updates, + drop_pending_updates, + flag: Some(flag), + token, + stop_token_cloned: false, + }; assert_update_listener(polling) } @@ -240,17 +249,21 @@ pub struct Polling { limit: Option, allowed_updates: Option>, drop_pending_updates: bool, - flag: StopFlag, + flag: Option, token: StopToken, + stop_token_cloned: bool, } impl Polling where - R: Requester + Send + 'static, - ::GetUpdates: Send, + R: Requester, { /// Returns a builder for polling update listener. - pub fn builder(bot: R) -> PollingBuilder { + pub fn builder(bot: R) -> PollingBuilder + where + R: Send + 'static, + ::GetUpdates: Send, + { PollingBuilder { bot, timeout: None, @@ -259,6 +272,19 @@ where drop_pending_updates: false, } } + + /// Returns true if re-initialization happened *and* + /// the previous token was cloned. + fn reinit_stop_flag_if_needed(&mut self) -> bool { + if self.flag.is_some() { + return false; + } + + let (token, flag) = mk_stop_token(); + self.token = token; + self.flag = Some(flag); + mem::replace(&mut self.stop_token_cloned, false) + } } #[pin_project::pin_project] @@ -287,12 +313,18 @@ pub struct PollingStream<'a, B: Requester> { /// In-flight `get_updates()` call. #[pin] in_flight: Option<::Send>, + + /// The flag that notifies polling to stop polling. + #[pin] + flag: StopFlag, } impl UpdateListener for Polling { type Err = B::Err; fn stop_token(&mut self) -> StopToken { + self.reinit_stop_flag_if_needed(); + self.stop_token_cloned = true; self.token.clone() } @@ -311,6 +343,21 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling { let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")); let allowed_updates = self.allowed_updates.clone(); let drop_pending_updates = self.drop_pending_updates; + + let token_used_and_updated = self.reinit_stop_flag_if_needed(); + + // FIXME: document that `as_stream` is a destructive operation, actually, + // and you need to call `stop_token` *again* after it + if token_used_and_updated { + panic!( + "detected calling `as_stream` a second time after calling `stop_token`. \ + `as_stream` updates the stop token, thus you need to call it again after calling \ + `as_stream`" + ) + } + + // Unwrap: just called reinit + let flag = self.flag.take().unwrap(); PollingStream { polling: self, drop_pending_updates, @@ -321,6 +368,7 @@ impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling { stopping: false, buffer: Vec::new().into_iter(), in_flight: None, + flag, } } } @@ -343,7 +391,10 @@ impl Stream for PollingStream<'_, B> { // Check if we should stop and if so — drop in flight request, // we don't care about updates that happened *after* we started stopping - if !*this.stopping && this.polling.flag.is_stopped() { + // + // N.B.: it's important to use `poll` and not `is_stopped` here, + // so that *this stream* is polled when the flag is set to stop + if !*this.stopping && matches!(this.flag.poll(cx), Poll::Ready(())) { *this.stopping = true; log::trace!("dropping in-flight request");