diff --git a/src/dispatching/update_listeners/polling.rs b/src/dispatching/update_listeners/polling.rs index e229be98..50f570ee 100644 --- a/src/dispatching/update_listeners/polling.rs +++ b/src/dispatching/update_listeners/polling.rs @@ -1,16 +1,22 @@ -use std::{convert::TryInto, time::Duration}; - -use futures::{ - future::{ready, Either}, - stream::{self, Stream, StreamExt}, +use std::{ + convert::TryInto, + future::Future, + pin::Pin, + task::{ + self, + Poll::{self, Ready}, + }, + time::Duration, + vec, }; +use futures::{ready, stream::Stream}; + use crate::{ dispatching::{ stop_token::{AsyncStopFlag, AsyncStopToken}, - update_listeners::{stateful_listener::StatefulListener, UpdateListener}, + update_listeners::{AsUpdateStream, UpdateListener}, }, - payloads::{GetUpdates, GetUpdatesSetters as _}, requests::{HasPayload, Request, Requester}, types::{AllowedUpdate, Update}, }; @@ -197,93 +203,8 @@ where R: Requester + Send + 'static, ::GetUpdates: Send, { - struct State { - bot: B, - timeout: Option, - limit: Option, - allowed_updates: Option>, - offset: i32, - flag: AsyncStopFlag, - token: AsyncStopToken, - force_stop: bool, - } - - fn stream(st: &mut State) -> impl Stream> + Send + '_ - where - B: Requester + Send, - ::GetUpdates: Send, - { - stream::unfold(st, move |state| async move { - let State { timeout, limit, allowed_updates, bot, offset, flag, force_stop, .. } = - &mut *state; - - if *force_stop { - return None; - } - - if flag.is_stopped() { - let mut req = bot.get_updates().offset(*offset).timeout(0).limit(1); - req.payload_mut().allowed_updates = allowed_updates.take(); - - return match req.send().await { - Ok(_) => None, - Err(err) => { - // Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496 - *force_stop = true; - - Some((Either::Left(stream::once(ready(Err(err)))), state)) - } - }; - } - - let mut req = bot.get_updates(); - *req.payload_mut() = GetUpdates { - offset: Some(*offset), - timeout: *timeout, - limit: *limit, - allowed_updates: allowed_updates.take(), - }; - - match req.send().await { - Ok(updates) => { - // Set offset to the last update's id + 1 - if let Some(upd) = updates.last() { - *offset = upd.id + 1; - } - - let updates = updates.into_iter().map(Ok); - Some((Either::Right(stream::iter(updates)), state)) - } - Err(err) => Some((Either::Left(stream::once(ready(Err(err)))), state)), - } - }) - .flatten() - } - let (token, flag) = AsyncStopToken::new_pair(); - - let state = State { - bot, - timeout: timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")), - limit, - allowed_updates, - offset: 0, - flag, - token, - force_stop: false, - }; - - let stop_token = |st: &mut State<_>| st.token.clone(); - - let hint_allowed_updates = - Some(|state: &mut State<_>, allowed: &mut dyn Iterator| { - // TODO: we should probably warn if there already were different allowed updates - // before - state.allowed_updates = Some(allowed.collect()); - }); - let timeout_hint = Some(move |_: &State<_>| timeout); - - StatefulListener::new_with_hints(state, stream, stop_token, hint_allowed_updates, timeout_hint) + Polling { bot, timeout, limit, allowed_updates, flag, token } } async fn delete_webhook_if_setup(requester: &R) @@ -307,6 +228,143 @@ where } } +struct Polling { + bot: B, + timeout: Option, + limit: Option, + allowed_updates: Option>, + flag: AsyncStopFlag, + token: AsyncStopToken, +} + +#[pin_project::pin_project] +struct PollingStream<'a, B: Requester> { + /// Parent structure + polling: &'a mut Polling, + + /// Timeout parameter for normal `get_updates()` calls. + timeout: Option, + /// Allowed updates parameter for the first `get_updates()` call. + allowed_updates: Option>, + /// Offset parameter for normal `get_updates()` calls. + offset: i32, + + /// If this is set, return `None` from `poll_next` immediately. + force_stop: bool, + /// If true we've sent last `get_updates()` call for graceful shutdown. + stopping: bool, + + /// Buffer of updates to be yielded. + buffer: vec::IntoIter, + + /// In-flight `get_updates()` call. + #[pin] + in_flight: Option<::Send>, +} + +impl UpdateListener for Polling { + type StopToken = AsyncStopToken; + + fn stop_token(&mut self) -> Self::StopToken { + self.token.clone() + } + + fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator) { + // TODO: we should probably warn if there already were different allowed updates + // before + self.allowed_updates = Some(hint.collect()); + } + + fn timeout_hint(&self) -> Option { + self.timeout + } +} + +impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a, B::Err> for Polling { + type Stream = PollingStream<'a, B>; + + fn as_stream(&'a mut self) -> Self::Stream { + let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")); + let allowed_updates = self.allowed_updates.clone(); + PollingStream { + polling: self, + timeout, + allowed_updates, + offset: 0, + force_stop: false, + stopping: false, + buffer: Vec::new().into_iter(), + in_flight: None, + } + } +} + +impl Stream for PollingStream<'_, B> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let mut this = self.as_mut().project(); + + if *this.force_stop { + return Ready(None); + } + + // Poll in-flight future until completion + if let Some(in_flight) = this.in_flight.as_mut().as_pin_mut() { + let res = ready!(in_flight.poll(cx)); + this.in_flight.set(None); + + match res { + Ok(_) if *this.stopping => return Ready(None), + Err(err) if *this.stopping => { + // Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496 + *this.force_stop = true; + + return Ready(Some(Err(err))); + } + Ok(updates) => { + if let Some(upd) = updates.last() { + *this.offset = upd.id + 1; + } + + *this.buffer = updates.into_iter(); + } + Err(err) => return Ready(Some(Err(err))), + } + } + + // If there are any buffered updates, return one + if let Some(upd) = this.buffer.next() { + return Ready(Some(Ok(upd))); + } + + // When stopping we set `timeout = 0` and `limit = 1` so that `get_updates()` + // set last seen update (offset) and return immediately + let (timeout, limit) = if this.polling.flag.is_stopped() { + *this.stopping = true; + (Some(0), Some(1)) + } else { + (*this.timeout, this.polling.limit) + }; + + let req = this + .polling + .bot + .get_updates() + .with_payload_mut(|pay| { + pay.offset = Some(*this.offset); + pay.timeout = timeout; + pay.limit = limit; + pay.allowed_updates = this.allowed_updates.take(); + }) + .send(); + this.in_flight.set(Some(req)); + + // Recurse to poll `self.in_flight` + self.poll_next(cx) + } +} + #[test] fn polling_is_send() { use crate::dispatching::update_listeners::AsUpdateStream;