[throttle] explicit queue-closed error handling & comments

This commit is contained in:
Waffle 2020-10-01 16:51:59 +03:00
parent 23ef060d08
commit 247868a815

View file

@ -8,7 +8,7 @@ use std::{
use futures::task::{Context, Poll};
use tokio::{
sync::{
mpsc,
mpsc::{self, error::TryRecvError},
oneshot::{channel, Receiver, Sender},
},
time::delay_for,
@ -48,6 +48,23 @@ impl Default for Limits {
}
}
/// <we need the doc here>
///
/// ## Note about send-by-@username
///
/// Telegram have limits on sending messages to _the same chat_. To check them
/// we store `chat_id`s of several last requests. _However_ there is no good way
/// to tell if given `ChatId::Id(x)` corresponds to the same chat as
/// `ChatId::ChannelUsername(u)`.
///
/// Our current approach is to just give up and check `chat_id_a == chat_id_b`.
/// This may give incorrect results.
///
/// Also, current algorithm requires to `clone` `chat_id` several times, which
/// can be quire expensive for strings (though this may be fixed in the future)
///
/// As such, we encourage not to use `ChatId::ChannelUsername(u)` with this bot
/// wrapper.
pub struct Throttle<B> {
bot: B,
queue: mpsc::Sender<(ChatId, Sender<()>)>,
@ -72,13 +89,22 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
loop {
// If there are no pending requests we are just waiting
if queue.is_empty() {
queue.push(queue_rx.recv().await.unwrap());
let req = queue_rx
.recv()
.await
// FIXME(waffle): decide what should we do on channel close
.expect("Queue channel was closed");
queue.push(req);
}
// update local queue with latest requests
while let Ok(e) = queue_rx.try_recv() {
// FIXME: properly check for errors (stop when the bot's sender is dropped?)
queue.push(e);
loop {
match queue_rx.try_recv() {
Ok(req) => queue.push(req),
Err(TryRecvError::Empty) => break,
// FIXME(waffle): decide what should we do on channel close
Err(TryRecvError::Closed) => unimplemented!("Queue channel was closed"),
}
}
// _Maybe_ we need to use `spawn_blocking` here, because there is
@ -126,9 +152,11 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
}
if let Some((chat, _)) = history.pop_front() {
if let Entry::Occupied(entry) = hchats.entry(chat).and_modify(|count| {
let ent = hchats.entry(chat).and_modify(|count| {
*count -= 1;
}) {
});
if let Entry::Occupied(entry) = ent {
if *entry.get() == 0 {
entry.remove_entry();
}
@ -138,9 +166,8 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
// as truncates which is ok since in case of truncation it would always be >=
// limits.overall_s
let mut allowed = limits
.overall_s
.saturating_sub(history.iter().take_while(|(_, time)| time > &sec_back).count() as u32);
let used = history.iter().take_while(|(_, time)| time > &sec_back).count() as u32;
let mut allowed = limits.overall_s.saturating_sub(used);
if allowed == 0 {
hchats_s.clear();
@ -166,8 +193,14 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
*hchats.entry(chat.clone()).or_insert(0) += 1;
history.push_back((chat.clone(), Instant::now()));
}
entry.remove().1.send(());
// Explicitly ignore result.
//
// If request doesn't listen to unlock channel we don't want
// anything to do with it.
let _ = entry.remove().1.send(());
// We've "sent" 1 request, so now we can send 1 less
allowed -= 1;
if allowed == 0 {
break;
@ -178,6 +211,8 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(ChatId, Sender<()>
}
drop(queue_rem);
// It's easier to just recompute last second stats, instead of keeping
// track of it alongside with minute stats, so we just throw this away.
hchats_s.clear();
delay_for(DELAY).await;
}
@ -251,7 +286,6 @@ impl<B: Requester> Requester for Throttle<B> {
}
pub trait GetChatId {
// FIXME(waffle): add note about false negatives with ChatId::Username
fn get_chat_id(&self) -> &ChatId;
}
@ -280,13 +314,13 @@ where
<R as HasPayload>::Payload: GetChatId,
{
type Err = R::Err;
type Send = LimitedSend<R>;
type SendRef = LimitedSend<R>;
type Send = ThrottlingSend<R>;
type SendRef = ThrottlingSend<R>;
fn send(self) -> Self::Send {
let (tx, rx) = channel();
let send = self.1.send_t((self.0.payload_ref().get_chat_id().clone(), tx));
LimitedSend::Registering { request: self.0, send, wait: rx }
ThrottlingSend::Registering { request: self.0, send, wait: rx }
}
fn send_ref(&self) -> Self::SendRef {
@ -295,7 +329,7 @@ where
}
#[pin_project::pin_project(project = SendProj, project_replace = SendRepl)]
pub enum LimitedSend<R: Request> {
pub enum ThrottlingSend<R: Request> {
Registering {
request: R,
#[pin]
@ -314,7 +348,7 @@ pub enum LimitedSend<R: Request> {
Done,
}
impl<R: Request> Future for LimitedSend<R> {
impl<R: Request> Future for ThrottlingSend<R> {
type Output = Result<Output<R>, R::Err>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -325,9 +359,9 @@ impl<R: Request> Future for LimitedSend<R> {
// FIXME(waffle): remove unwrap
r.unwrap();
if let SendRepl::Registering { request, send: _, wait } =
self.as_mut().project_replace(LimitedSend::Done)
self.as_mut().project_replace(ThrottlingSend::Done)
{
self.as_mut().project_replace(LimitedSend::Pending { request, wait });
self.as_mut().project_replace(ThrottlingSend::Pending { request, wait });
}
self.poll(cx)
@ -339,9 +373,9 @@ impl<R: Request> Future for LimitedSend<R> {
// FIXME(waffle): remove unwrap
r.unwrap();
if let SendRepl::Pending { request, wait: _ } =
self.as_mut().project_replace(LimitedSend::Done)
self.as_mut().project_replace(ThrottlingSend::Done)
{
self.as_mut().project_replace(LimitedSend::Sent { fut: request.send() });
self.as_mut().project_replace(ThrottlingSend::Sent { fut: request.send() });
}
self.poll(cx)
@ -349,7 +383,7 @@ impl<R: Request> Future for LimitedSend<R> {
},
SendProj::Sent { fut } => {
let res = futures::ready!(fut.poll(cx));
self.set(LimitedSend::Done);
self.set(ThrottlingSend::Done);
Poll::Ready(res)
}
SendProj::Done => Poll::Pending,
@ -358,12 +392,13 @@ impl<R: Request> Future for LimitedSend<R> {
}
mod chan_send {
use crate::types::ChatId;
use std::{future::Future, pin::Pin};
use futures::task::{Context, Poll};
use pin_project::__private::Pin;
use std::future::Future;
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
use crate::types::ChatId;
pub(crate) trait SendTy {
fn send_t(self, val: (ChatId, Sender<()>)) -> ChanSend;
}