Implement polling stream by hand

IMO it's actually clearer & nicer than the old impl. +The types are now
nameable.


Former-commit-id: 82fc756aab
This commit is contained in:
Maybe Waffle 2022-06-27 02:03:34 +04:00
parent a839b47106
commit 2ceccdf442

View file

@ -1,16 +1,22 @@
use std::{convert::TryInto, time::Duration};
use futures::{
future::{ready, Either},
stream::{self, Stream, StreamExt},
use std::{
convert::TryInto,
future::Future,
pin::Pin,
task::{
self,
Poll::{self, Ready},
},
time::Duration,
vec,
};
use futures::{ready, stream::Stream};
use crate::{
dispatching::{
stop_token::{AsyncStopFlag, AsyncStopToken},
update_listeners::{stateful_listener::StatefulListener, UpdateListener},
update_listeners::{AsUpdateStream, UpdateListener},
},
payloads::{GetUpdates, GetUpdatesSetters as _},
requests::{HasPayload, Request, Requester},
types::{AllowedUpdate, Update},
};
@ -197,93 +203,8 @@ where
R: Requester + Send + 'static,
<R as Requester>::GetUpdates: Send,
{
struct State<B: Requester> {
bot: B,
timeout: Option<u32>,
limit: Option<u8>,
allowed_updates: Option<Vec<AllowedUpdate>>,
offset: i32,
flag: AsyncStopFlag,
token: AsyncStopToken,
force_stop: bool,
}
fn stream<B>(st: &mut State<B>) -> impl Stream<Item = Result<Update, B::Err>> + Send + '_
where
B: Requester + Send,
<B as Requester>::GetUpdates: Send,
{
stream::unfold(st, move |state| async move {
let State { timeout, limit, allowed_updates, bot, offset, flag, force_stop, .. } =
&mut *state;
if *force_stop {
return None;
}
if flag.is_stopped() {
let mut req = bot.get_updates().offset(*offset).timeout(0).limit(1);
req.payload_mut().allowed_updates = allowed_updates.take();
return match req.send().await {
Ok(_) => None,
Err(err) => {
// Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496
*force_stop = true;
Some((Either::Left(stream::once(ready(Err(err)))), state))
}
};
}
let mut req = bot.get_updates();
*req.payload_mut() = GetUpdates {
offset: Some(*offset),
timeout: *timeout,
limit: *limit,
allowed_updates: allowed_updates.take(),
};
match req.send().await {
Ok(updates) => {
// Set offset to the last update's id + 1
if let Some(upd) = updates.last() {
*offset = upd.id + 1;
}
let updates = updates.into_iter().map(Ok);
Some((Either::Right(stream::iter(updates)), state))
}
Err(err) => Some((Either::Left(stream::once(ready(Err(err)))), state)),
}
})
.flatten()
}
let (token, flag) = AsyncStopToken::new_pair();
let state = State {
bot,
timeout: timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")),
limit,
allowed_updates,
offset: 0,
flag,
token,
force_stop: false,
};
let stop_token = |st: &mut State<_>| st.token.clone();
let hint_allowed_updates =
Some(|state: &mut State<_>, allowed: &mut dyn Iterator<Item = AllowedUpdate>| {
// TODO: we should probably warn if there already were different allowed updates
// before
state.allowed_updates = Some(allowed.collect());
});
let timeout_hint = Some(move |_: &State<_>| timeout);
StatefulListener::new_with_hints(state, stream, stop_token, hint_allowed_updates, timeout_hint)
Polling { bot, timeout, limit, allowed_updates, flag, token }
}
async fn delete_webhook_if_setup<R>(requester: &R)
@ -307,6 +228,143 @@ where
}
}
struct Polling<B: Requester> {
bot: B,
timeout: Option<Duration>,
limit: Option<u8>,
allowed_updates: Option<Vec<AllowedUpdate>>,
flag: AsyncStopFlag,
token: AsyncStopToken,
}
#[pin_project::pin_project]
struct PollingStream<'a, B: Requester> {
/// Parent structure
polling: &'a mut Polling<B>,
/// Timeout parameter for normal `get_updates()` calls.
timeout: Option<u32>,
/// Allowed updates parameter for the first `get_updates()` call.
allowed_updates: Option<Vec<AllowedUpdate>>,
/// Offset parameter for normal `get_updates()` calls.
offset: i32,
/// If this is set, return `None` from `poll_next` immediately.
force_stop: bool,
/// If true we've sent last `get_updates()` call for graceful shutdown.
stopping: bool,
/// Buffer of updates to be yielded.
buffer: vec::IntoIter<Update>,
/// In-flight `get_updates()` call.
#[pin]
in_flight: Option<<B::GetUpdates as Request>::Send>,
}
impl<B: Requester + Send + 'static> UpdateListener<B::Err> for Polling<B> {
type StopToken = AsyncStopToken;
fn stop_token(&mut self) -> Self::StopToken {
self.token.clone()
}
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
// TODO: we should probably warn if there already were different allowed updates
// before
self.allowed_updates = Some(hint.collect());
}
fn timeout_hint(&self) -> Option<Duration> {
self.timeout
}
}
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a, B::Err> for Polling<B> {
type Stream = PollingStream<'a, B>;
fn as_stream(&'a mut self) -> Self::Stream {
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
let allowed_updates = self.allowed_updates.clone();
PollingStream {
polling: self,
timeout,
allowed_updates,
offset: 0,
force_stop: false,
stopping: false,
buffer: Vec::new().into_iter(),
in_flight: None,
}
}
}
impl<B: Requester> Stream for PollingStream<'_, B> {
type Item = Result<Update, B::Err>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
if *this.force_stop {
return Ready(None);
}
// Poll in-flight future until completion
if let Some(in_flight) = this.in_flight.as_mut().as_pin_mut() {
let res = ready!(in_flight.poll(cx));
this.in_flight.set(None);
match res {
Ok(_) if *this.stopping => return Ready(None),
Err(err) if *this.stopping => {
// Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496
*this.force_stop = true;
return Ready(Some(Err(err)));
}
Ok(updates) => {
if let Some(upd) = updates.last() {
*this.offset = upd.id + 1;
}
*this.buffer = updates.into_iter();
}
Err(err) => return Ready(Some(Err(err))),
}
}
// If there are any buffered updates, return one
if let Some(upd) = this.buffer.next() {
return Ready(Some(Ok(upd)));
}
// When stopping we set `timeout = 0` and `limit = 1` so that `get_updates()`
// set last seen update (offset) and return immediately
let (timeout, limit) = if this.polling.flag.is_stopped() {
*this.stopping = true;
(Some(0), Some(1))
} else {
(*this.timeout, this.polling.limit)
};
let req = this
.polling
.bot
.get_updates()
.with_payload_mut(|pay| {
pay.offset = Some(*this.offset);
pay.timeout = timeout;
pay.limit = limit;
pay.allowed_updates = this.allowed_updates.take();
})
.send();
this.in_flight.set(Some(req));
// Recurse to poll `self.in_flight`
self.poll_next(cx)
}
}
#[test]
fn polling_is_send() {
use crate::dispatching::update_listeners::AsUpdateStream;