mirror of
https://github.com/teloxide/teloxide.git
synced 2024-12-22 14:35:36 +01:00
Implement polling stream by hand
IMO it's actually clearer & nicer than the old impl. +The types are now
nameable.
Former-commit-id: 82fc756aab
This commit is contained in:
parent
a839b47106
commit
2ceccdf442
1 changed files with 151 additions and 93 deletions
|
@ -1,16 +1,22 @@
|
|||
use std::{convert::TryInto, time::Duration};
|
||||
|
||||
use futures::{
|
||||
future::{ready, Either},
|
||||
stream::{self, Stream, StreamExt},
|
||||
use std::{
|
||||
convert::TryInto,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{
|
||||
self,
|
||||
Poll::{self, Ready},
|
||||
},
|
||||
time::Duration,
|
||||
vec,
|
||||
};
|
||||
|
||||
use futures::{ready, stream::Stream};
|
||||
|
||||
use crate::{
|
||||
dispatching::{
|
||||
stop_token::{AsyncStopFlag, AsyncStopToken},
|
||||
update_listeners::{stateful_listener::StatefulListener, UpdateListener},
|
||||
update_listeners::{AsUpdateStream, UpdateListener},
|
||||
},
|
||||
payloads::{GetUpdates, GetUpdatesSetters as _},
|
||||
requests::{HasPayload, Request, Requester},
|
||||
types::{AllowedUpdate, Update},
|
||||
};
|
||||
|
@ -197,93 +203,8 @@ where
|
|||
R: Requester + Send + 'static,
|
||||
<R as Requester>::GetUpdates: Send,
|
||||
{
|
||||
struct State<B: Requester> {
|
||||
bot: B,
|
||||
timeout: Option<u32>,
|
||||
limit: Option<u8>,
|
||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||
offset: i32,
|
||||
flag: AsyncStopFlag,
|
||||
token: AsyncStopToken,
|
||||
force_stop: bool,
|
||||
}
|
||||
|
||||
fn stream<B>(st: &mut State<B>) -> impl Stream<Item = Result<Update, B::Err>> + Send + '_
|
||||
where
|
||||
B: Requester + Send,
|
||||
<B as Requester>::GetUpdates: Send,
|
||||
{
|
||||
stream::unfold(st, move |state| async move {
|
||||
let State { timeout, limit, allowed_updates, bot, offset, flag, force_stop, .. } =
|
||||
&mut *state;
|
||||
|
||||
if *force_stop {
|
||||
return None;
|
||||
}
|
||||
|
||||
if flag.is_stopped() {
|
||||
let mut req = bot.get_updates().offset(*offset).timeout(0).limit(1);
|
||||
req.payload_mut().allowed_updates = allowed_updates.take();
|
||||
|
||||
return match req.send().await {
|
||||
Ok(_) => None,
|
||||
Err(err) => {
|
||||
// Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496
|
||||
*force_stop = true;
|
||||
|
||||
Some((Either::Left(stream::once(ready(Err(err)))), state))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let mut req = bot.get_updates();
|
||||
*req.payload_mut() = GetUpdates {
|
||||
offset: Some(*offset),
|
||||
timeout: *timeout,
|
||||
limit: *limit,
|
||||
allowed_updates: allowed_updates.take(),
|
||||
};
|
||||
|
||||
match req.send().await {
|
||||
Ok(updates) => {
|
||||
// Set offset to the last update's id + 1
|
||||
if let Some(upd) = updates.last() {
|
||||
*offset = upd.id + 1;
|
||||
}
|
||||
|
||||
let updates = updates.into_iter().map(Ok);
|
||||
Some((Either::Right(stream::iter(updates)), state))
|
||||
}
|
||||
Err(err) => Some((Either::Left(stream::once(ready(Err(err)))), state)),
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
}
|
||||
|
||||
let (token, flag) = AsyncStopToken::new_pair();
|
||||
|
||||
let state = State {
|
||||
bot,
|
||||
timeout: timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")),
|
||||
limit,
|
||||
allowed_updates,
|
||||
offset: 0,
|
||||
flag,
|
||||
token,
|
||||
force_stop: false,
|
||||
};
|
||||
|
||||
let stop_token = |st: &mut State<_>| st.token.clone();
|
||||
|
||||
let hint_allowed_updates =
|
||||
Some(|state: &mut State<_>, allowed: &mut dyn Iterator<Item = AllowedUpdate>| {
|
||||
// TODO: we should probably warn if there already were different allowed updates
|
||||
// before
|
||||
state.allowed_updates = Some(allowed.collect());
|
||||
});
|
||||
let timeout_hint = Some(move |_: &State<_>| timeout);
|
||||
|
||||
StatefulListener::new_with_hints(state, stream, stop_token, hint_allowed_updates, timeout_hint)
|
||||
Polling { bot, timeout, limit, allowed_updates, flag, token }
|
||||
}
|
||||
|
||||
async fn delete_webhook_if_setup<R>(requester: &R)
|
||||
|
@ -307,6 +228,143 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct Polling<B: Requester> {
|
||||
bot: B,
|
||||
timeout: Option<Duration>,
|
||||
limit: Option<u8>,
|
||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||
flag: AsyncStopFlag,
|
||||
token: AsyncStopToken,
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
struct PollingStream<'a, B: Requester> {
|
||||
/// Parent structure
|
||||
polling: &'a mut Polling<B>,
|
||||
|
||||
/// Timeout parameter for normal `get_updates()` calls.
|
||||
timeout: Option<u32>,
|
||||
/// Allowed updates parameter for the first `get_updates()` call.
|
||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||
/// Offset parameter for normal `get_updates()` calls.
|
||||
offset: i32,
|
||||
|
||||
/// If this is set, return `None` from `poll_next` immediately.
|
||||
force_stop: bool,
|
||||
/// If true we've sent last `get_updates()` call for graceful shutdown.
|
||||
stopping: bool,
|
||||
|
||||
/// Buffer of updates to be yielded.
|
||||
buffer: vec::IntoIter<Update>,
|
||||
|
||||
/// In-flight `get_updates()` call.
|
||||
#[pin]
|
||||
in_flight: Option<<B::GetUpdates as Request>::Send>,
|
||||
}
|
||||
|
||||
impl<B: Requester + Send + 'static> UpdateListener<B::Err> for Polling<B> {
|
||||
type StopToken = AsyncStopToken;
|
||||
|
||||
fn stop_token(&mut self) -> Self::StopToken {
|
||||
self.token.clone()
|
||||
}
|
||||
|
||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
||||
// TODO: we should probably warn if there already were different allowed updates
|
||||
// before
|
||||
self.allowed_updates = Some(hint.collect());
|
||||
}
|
||||
|
||||
fn timeout_hint(&self) -> Option<Duration> {
|
||||
self.timeout
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a, B::Err> for Polling<B> {
|
||||
type Stream = PollingStream<'a, B>;
|
||||
|
||||
fn as_stream(&'a mut self) -> Self::Stream {
|
||||
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
|
||||
let allowed_updates = self.allowed_updates.clone();
|
||||
PollingStream {
|
||||
polling: self,
|
||||
timeout,
|
||||
allowed_updates,
|
||||
offset: 0,
|
||||
force_stop: false,
|
||||
stopping: false,
|
||||
buffer: Vec::new().into_iter(),
|
||||
in_flight: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||
type Item = Result<Update, B::Err>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.as_mut().project();
|
||||
|
||||
if *this.force_stop {
|
||||
return Ready(None);
|
||||
}
|
||||
|
||||
// Poll in-flight future until completion
|
||||
if let Some(in_flight) = this.in_flight.as_mut().as_pin_mut() {
|
||||
let res = ready!(in_flight.poll(cx));
|
||||
this.in_flight.set(None);
|
||||
|
||||
match res {
|
||||
Ok(_) if *this.stopping => return Ready(None),
|
||||
Err(err) if *this.stopping => {
|
||||
// Prevents infinite retries, see https://github.com/teloxide/teloxide/issues/496
|
||||
*this.force_stop = true;
|
||||
|
||||
return Ready(Some(Err(err)));
|
||||
}
|
||||
Ok(updates) => {
|
||||
if let Some(upd) = updates.last() {
|
||||
*this.offset = upd.id + 1;
|
||||
}
|
||||
|
||||
*this.buffer = updates.into_iter();
|
||||
}
|
||||
Err(err) => return Ready(Some(Err(err))),
|
||||
}
|
||||
}
|
||||
|
||||
// If there are any buffered updates, return one
|
||||
if let Some(upd) = this.buffer.next() {
|
||||
return Ready(Some(Ok(upd)));
|
||||
}
|
||||
|
||||
// When stopping we set `timeout = 0` and `limit = 1` so that `get_updates()`
|
||||
// set last seen update (offset) and return immediately
|
||||
let (timeout, limit) = if this.polling.flag.is_stopped() {
|
||||
*this.stopping = true;
|
||||
(Some(0), Some(1))
|
||||
} else {
|
||||
(*this.timeout, this.polling.limit)
|
||||
};
|
||||
|
||||
let req = this
|
||||
.polling
|
||||
.bot
|
||||
.get_updates()
|
||||
.with_payload_mut(|pay| {
|
||||
pay.offset = Some(*this.offset);
|
||||
pay.timeout = timeout;
|
||||
pay.limit = limit;
|
||||
pay.allowed_updates = this.allowed_updates.take();
|
||||
})
|
||||
.send();
|
||||
this.in_flight.set(Some(req));
|
||||
|
||||
// Recurse to poll `self.in_flight`
|
||||
self.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn polling_is_send() {
|
||||
use crate::dispatching::update_listeners::AsUpdateStream;
|
||||
|
|
Loading…
Reference in a new issue