mirror of
https://github.com/teloxide/teloxide.git
synced 2025-03-14 11:44:04 +01:00
[throttle] explicit queue-closed error handling & comments
This commit is contained in:
parent
23ef060d08
commit
247868a815
1 changed files with 60 additions and 25 deletions
|
@ -8,7 +8,7 @@ use std::{
|
|||
use futures::task::{Context, Poll};
|
||||
use tokio::{
|
||||
sync::{
|
||||
mpsc,
|
||||
mpsc::{self, error::TryRecvError},
|
||||
oneshot::{channel, Receiver, Sender},
|
||||
},
|
||||
time::delay_for,
|
||||
|
@ -48,6 +48,23 @@ impl Default for Limits {
|
|||
}
|
||||
}
|
||||
|
||||
/// <we need the doc here>
|
||||
///
|
||||
/// ## Note about send-by-@username
|
||||
///
|
||||
/// Telegram have limits on sending messages to _the same chat_. To check them
|
||||
/// we store `chat_id`s of several last requests. _However_ there is no good way
|
||||
/// to tell if given `ChatId::Id(x)` corresponds to the same chat as
|
||||
/// `ChatId::ChannelUsername(u)`.
|
||||
///
|
||||
/// Our current approach is to just give up and check `chat_id_a == chat_id_b`.
|
||||
/// This may give incorrect results.
|
||||
///
|
||||
/// Also, current algorithm requires to `clone` `chat_id` several times, which
|
||||
/// can be quire expensive for strings (though this may be fixed in the future)
|
||||
///
|
||||
/// As such, we encourage not to use `ChatId::ChannelUsername(u)` with this bot
|
||||
/// wrapper.
|
||||
pub struct Throttle<B> {
|
||||
bot: B,
|
||||
queue: mpsc::Sender<(ChatId, Sender<()>)>,
|
||||
|
@ -72,13 +89,22 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
|
|||
loop {
|
||||
// If there are no pending requests we are just waiting
|
||||
if queue.is_empty() {
|
||||
queue.push(queue_rx.recv().await.unwrap());
|
||||
let req = queue_rx
|
||||
.recv()
|
||||
.await
|
||||
// FIXME(waffle): decide what should we do on channel close
|
||||
.expect("Queue channel was closed");
|
||||
queue.push(req);
|
||||
}
|
||||
|
||||
// update local queue with latest requests
|
||||
while let Ok(e) = queue_rx.try_recv() {
|
||||
// FIXME: properly check for errors (stop when the bot's sender is dropped?)
|
||||
queue.push(e);
|
||||
loop {
|
||||
match queue_rx.try_recv() {
|
||||
Ok(req) => queue.push(req),
|
||||
Err(TryRecvError::Empty) => break,
|
||||
// FIXME(waffle): decide what should we do on channel close
|
||||
Err(TryRecvError::Closed) => unimplemented!("Queue channel was closed"),
|
||||
}
|
||||
}
|
||||
|
||||
// _Maybe_ we need to use `spawn_blocking` here, because there is
|
||||
|
@ -126,9 +152,11 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
|
|||
}
|
||||
|
||||
if let Some((chat, _)) = history.pop_front() {
|
||||
if let Entry::Occupied(entry) = hchats.entry(chat).and_modify(|count| {
|
||||
let ent = hchats.entry(chat).and_modify(|count| {
|
||||
*count -= 1;
|
||||
}) {
|
||||
});
|
||||
|
||||
if let Entry::Occupied(entry) = ent {
|
||||
if *entry.get() == 0 {
|
||||
entry.remove_entry();
|
||||
}
|
||||
|
@ -138,9 +166,8 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
|
|||
|
||||
// as truncates which is ok since in case of truncation it would always be >=
|
||||
// limits.overall_s
|
||||
let mut allowed = limits
|
||||
.overall_s
|
||||
.saturating_sub(history.iter().take_while(|(_, time)| time > &sec_back).count() as u32);
|
||||
let used = history.iter().take_while(|(_, time)| time > &sec_back).count() as u32;
|
||||
let mut allowed = limits.overall_s.saturating_sub(used);
|
||||
|
||||
if allowed == 0 {
|
||||
hchats_s.clear();
|
||||
|
@ -166,8 +193,14 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
|
|||
*hchats.entry(chat.clone()).or_insert(0) += 1;
|
||||
history.push_back((chat.clone(), Instant::now()));
|
||||
}
|
||||
entry.remove().1.send(());
|
||||
|
||||
// Explicitly ignore result.
|
||||
//
|
||||
// If request doesn't listen to unlock channel we don't want
|
||||
// anything to do with it.
|
||||
let _ = entry.remove().1.send(());
|
||||
|
||||
// We've "sent" 1 request, so now we can send 1 less
|
||||
allowed -= 1;
|
||||
if allowed == 0 {
|
||||
break;
|
||||
|
@ -178,6 +211,8 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
|
|||
}
|
||||
drop(queue_rem);
|
||||
|
||||
// It's easier to just recompute last second stats, instead of keeping
|
||||
// track of it alongside with minute stats, so we just throw this away.
|
||||
hchats_s.clear();
|
||||
delay_for(DELAY).await;
|
||||
}
|
||||
|
@ -251,7 +286,6 @@ impl<B: Requester> Requester for Throttle<B> {
|
|||
}
|
||||
|
||||
pub trait GetChatId {
|
||||
// FIXME(waffle): add note about false negatives with ChatId::Username
|
||||
fn get_chat_id(&self) -> &ChatId;
|
||||
}
|
||||
|
||||
|
@ -280,13 +314,13 @@ where
|
|||
<R as HasPayload>::Payload: GetChatId,
|
||||
{
|
||||
type Err = R::Err;
|
||||
type Send = LimitedSend<R>;
|
||||
type SendRef = LimitedSend<R>;
|
||||
type Send = ThrottlingSend<R>;
|
||||
type SendRef = ThrottlingSend<R>;
|
||||
|
||||
fn send(self) -> Self::Send {
|
||||
let (tx, rx) = channel();
|
||||
let send = self.1.send_t((self.0.payload_ref().get_chat_id().clone(), tx));
|
||||
LimitedSend::Registering { request: self.0, send, wait: rx }
|
||||
ThrottlingSend::Registering { request: self.0, send, wait: rx }
|
||||
}
|
||||
|
||||
fn send_ref(&self) -> Self::SendRef {
|
||||
|
@ -295,7 +329,7 @@ where
|
|||
}
|
||||
|
||||
#[pin_project::pin_project(project = SendProj, project_replace = SendRepl)]
|
||||
pub enum LimitedSend<R: Request> {
|
||||
pub enum ThrottlingSend<R: Request> {
|
||||
Registering {
|
||||
request: R,
|
||||
#[pin]
|
||||
|
@ -314,7 +348,7 @@ pub enum LimitedSend<R: Request> {
|
|||
Done,
|
||||
}
|
||||
|
||||
impl<R: Request> Future for LimitedSend<R> {
|
||||
impl<R: Request> Future for ThrottlingSend<R> {
|
||||
type Output = Result<Output<R>, R::Err>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
|
@ -325,9 +359,9 @@ impl<R: Request> Future for LimitedSend<R> {
|
|||
// FIXME(waffle): remove unwrap
|
||||
r.unwrap();
|
||||
if let SendRepl::Registering { request, send: _, wait } =
|
||||
self.as_mut().project_replace(LimitedSend::Done)
|
||||
self.as_mut().project_replace(ThrottlingSend::Done)
|
||||
{
|
||||
self.as_mut().project_replace(LimitedSend::Pending { request, wait });
|
||||
self.as_mut().project_replace(ThrottlingSend::Pending { request, wait });
|
||||
}
|
||||
|
||||
self.poll(cx)
|
||||
|
@ -339,9 +373,9 @@ impl<R: Request> Future for LimitedSend<R> {
|
|||
// FIXME(waffle): remove unwrap
|
||||
r.unwrap();
|
||||
if let SendRepl::Pending { request, wait: _ } =
|
||||
self.as_mut().project_replace(LimitedSend::Done)
|
||||
self.as_mut().project_replace(ThrottlingSend::Done)
|
||||
{
|
||||
self.as_mut().project_replace(LimitedSend::Sent { fut: request.send() });
|
||||
self.as_mut().project_replace(ThrottlingSend::Sent { fut: request.send() });
|
||||
}
|
||||
|
||||
self.poll(cx)
|
||||
|
@ -349,7 +383,7 @@ impl<R: Request> Future for LimitedSend<R> {
|
|||
},
|
||||
SendProj::Sent { fut } => {
|
||||
let res = futures::ready!(fut.poll(cx));
|
||||
self.set(LimitedSend::Done);
|
||||
self.set(ThrottlingSend::Done);
|
||||
Poll::Ready(res)
|
||||
}
|
||||
SendProj::Done => Poll::Pending,
|
||||
|
@ -358,12 +392,13 @@ impl<R: Request> Future for LimitedSend<R> {
|
|||
}
|
||||
|
||||
mod chan_send {
|
||||
use crate::types::ChatId;
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use futures::task::{Context, Poll};
|
||||
use pin_project::__private::Pin;
|
||||
use std::future::Future;
|
||||
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
|
||||
|
||||
use crate::types::ChatId;
|
||||
|
||||
pub(crate) trait SendTy {
|
||||
fn send_t(self, val: (ChatId, Sender<()>)) -> ChanSend;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue