Merge pull request #46 from teloxide/refactor-throttle

Refactor requests throttling
Temirkhan Myrzamadi 2021-02-15 23:17:29 +03:00 committed by GitHub
commit 557002b43a

@@ -23,49 +23,48 @@ use crate::{
types::*,
};
// Throttling is quite complicated. This comment describes the algorithm of the
// current implementation.
//
// ### Request
//
// When a throttling request is sent, it sends a tuple of `ChatId` and
// `Sender<()>` to the worker. Then the request waits for a notification from
// the worker. When a notification is received, it sends the underlying request.
//
// ### Worker
//
// The worker does the most important job -- it ensures that the limits are
// never exceeded.
//
// The worker stores "history" of requests sent in last minute (and to which
// chats the were sent) and queue of pending updates.
// The worker stores a history of requests sent in the last minute (and to which
// chats they were sent) and a queue of pending updates.
//
// The worker runs the following loop:
//
// 1. If the queue is empty, wait for the first message from the incoming
// channel (and add it to the queue).
//
// 2. Read all messages currently in the incoming channel and transfer them to
// the queue.
//
// 3. Record the current time.
//
// 4. Remove from the history all records whose time < (current time - minute).
//
// 5. Count all requests that were sent in the last second and compute
// `allowed = limit.messages_per_sec_overall - count`.
//
// 6. If `allowed == 0`, wait a bit and `continue` to the next iteration.
//
// 7. Count how many requests were sent to each chat (i.e., create a
// `Map<ChatId, Count>`). (Note: the same map also exists for the last minute,
// but it is updated instead of being recreated.)
//
// 8. While `allowed > 0`, search for a request whose chat hasn't exceeded the
// limits (i.e., map[chat] < limit); if one is found, decrease `allowed`, notify
// the request that it can now be executed, increase the counts, and add a
// record to the history.
//
// Steps 3-5 are sketched in isolation right after this comment.
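For illustration, steps 3-5 above can be exercised in isolation. The following is a hypothetical, self-contained sketch (not part of the diff; plain `i64` chat ids stand in for the worker's actual id type):

```rust
use std::collections::VecDeque;
use std::time::{Duration, Instant};

// Hypothetical sketch of steps 3-5: trim the one-minute history, then compute
// how many more requests may be sent within the current second.
fn allowed_now(history: &mut VecDeque<(i64, Instant)>, overall_per_sec: u32) -> u32 {
    let now = Instant::now();
    // Step 4: drop records older than one minute (oldest entries are in front).
    while let Some(&(_, time)) = history.front() {
        if now.duration_since(time) > Duration::from_secs(60) {
            history.pop_front();
        } else {
            break;
        }
    }
    // Step 5: count requests sent within the last second (newest are in back).
    let used = history
        .iter()
        .rev()
        .take_while(|(_, time)| now.duration_since(*time) < Duration::from_secs(1))
        .count() as u32;
    overall_per_sec.saturating_sub(used)
}
```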
const MINUTE: Duration = Duration::from_secs(60);
const SECOND: Duration = Duration::from_secs(1);
@@ -86,12 +85,14 @@ const DELAY: Duration = Duration::from_millis(250);
/// [@BotSupport]: https://t.me/botsupport
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Limits {
/// Allowed messages in one chat per second.
pub messages_per_sec_chat: u32,
/// Allowed messages in one chat per minute.
pub messages_per_min_chat: u32,
/// Allowed messages per second.
pub messages_per_sec_overall: u32,
}
/// Defaults are taken from [telegram documentation][tgdoc].
@@ -100,9 +101,9 @@ pub struct Limits {
impl Default for Limits {
fn default() -> Self {
Self {
messages_per_sec_chat: 1,
messages_per_sec_overall: 30,
messages_per_min_chat: 20,
}
}
}
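As a usage sketch (hypothetical values), struct-update syntax makes it easy to tighten one limit while keeping the defaults above:

```rust
// Hypothetical: allow at most 10 messages per minute in any single chat,
// keeping the documented defaults for the other limits.
let limits = Limits {
    messages_per_min_chat: 10,
    ..Limits::default()
};
```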
@@ -149,48 +150,34 @@ impl Default for Limits {
pub struct Throttle<B> {
bot: B,
// Sender<Never> is used to pass the signal to unlock by closing the channel.
queue: mpsc::Sender<(ChatIdHash, Sender<Never>)>,
}
type RequestsSent = u32;
// I wish there was a special data structure for the history that would remove
// the need for two hashmaps. (waffle)
#[derive(Default)]
struct RequestsSentToChats {
per_min: HashMap<ChatIdHash, RequestsSent>,
per_sec: HashMap<ChatIdHash, RequestsSent>,
}
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, Sender<Never>)>) {
// FIXME(waffle): Research better data structures for this queue.
//                Currently removal is O(n) (n = number of remaining elements)
//                and push is amortized O(1) (vec + vecrem).
let mut queue: Vec<(ChatIdHash, Sender<Never>)> =
Vec::with_capacity(limits.messages_per_sec_overall as usize);
let mut history: VecDeque<(ChatIdHash, Instant)> = VecDeque::new();
let mut requests_sent = RequestsSentToChats::default();
// Set to true when `rx` is closed.
let mut rx_is_closed = false;
while !rx_is_closed || !queue.is_empty() {
read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await;
// _Maybe_ we need to use `spawn_blocking` here, because there is a
// decent amount of blocking work. However _for now_ I've decided not
@@ -237,11 +224,11 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(Id, Sender<Never>)
}
if let Some((chat, _)) = history.pop_front() {
let entry = requests_sent.per_min.entry(chat).and_modify(|count| {
*count -= 1;
});
if let Entry::Occupied(entry) = entry {
if *entry.get() == 0 {
entry.remove_entry();
}
@@ -255,71 +242,82 @@ async fn worker(limits: Limits, mut queue_rx: mpsc::Receiver<(Id, Sender<Never>)
.iter()
.take_while(|(_, time)| time > &sec_back)
.count() as u32;
let mut allowed = limits.messages_per_sec_overall.saturating_sub(used);
if allowed == 0 {
requests_sent.per_sec.clear();
tokio::time::sleep(DELAY).await;
continue;
}
for (chat, _) in history.iter().take_while(|(_, time)| time > &sec_back) {
*requests_sent.per_sec.entry(*chat).or_insert(0) += 1;
}
{
let mut queue_removing = queue.removing();
while let Some(entry) = queue_removing.next() {
let chat = &entry.value().0;
let requests_sent_per_sec = requests_sent.per_sec.get(chat).copied().unwrap_or(0);
let requests_sent_per_min = requests_sent.per_min.get(chat).copied().unwrap_or(0);
let limits_not_exceeded = requests_sent_per_sec < limits.messages_per_sec_chat
&& requests_sent_per_min < limits.messages_per_min_chat;
if limits_not_exceeded {
*requests_sent.per_sec.entry(*chat).or_insert(0) += 1;
*requests_sent.per_min.entry(*chat).or_insert(0) += 1;
history.push_back((*chat, Instant::now()));
// This closes the channel, unlocking the associated request.
drop(entry.remove());
// We have "sent" one request, so now we can send one less.
allowed -= 1;
if allowed == 0 {
break;
}
}
}
}
// It's easier to just recompute the last-second stats than to keep track of
// them alongside the minute stats, so we just throw them away.
requests_sent.per_sec.clear();
tokio::time::sleep(DELAY).await;
}
}
async fn read_from_rx(
rx: &mut mpsc::Receiver<(ChatIdHash, Sender<Never>)>,
queue: &mut Vec<(ChatIdHash, Sender<Never>)>,
rx_is_closed: &mut bool,
) {
if queue.is_empty() {
match rx.recv().await {
Some(req) => queue.push(req),
None => *rx_is_closed = true,
}
}
loop {
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
match rx.recv().now_or_never() {
Some(Some(req)) => queue.push(req),
Some(None) => *rx_is_closed = true,
// Nothing is ready in the channel right now.
None => break,
}
}
}
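The `now_or_never` pattern in `read_from_rx` is a general way to drain everything immediately available from a tokio channel without blocking; here is a minimal standalone sketch (hypothetical `u32` messages, not part of the diff). As the FIXME above notes, this interaction had a known tokio issue at the time.

```rust
use futures::FutureExt; // provides `now_or_never`
use tokio::sync::mpsc;

// Hypothetical sketch: move every message that is already waiting in the
// channel into `buf`, without ever blocking or yielding.
fn drain_ready(rx: &mut mpsc::Receiver<u32>, buf: &mut Vec<u32>) {
    loop {
        // Poll `recv` exactly once via `now_or_never`.
        match rx.recv().now_or_never() {
            Some(Some(msg)) => buf.push(msg), // a message was ready
            Some(None) => break,              // the channel is closed
            None => break,                    // nothing is ready right now
        }
    }
}
```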
impl<B> Throttle<B> {
/// Creates a new [`Throttle`] alongside a worker future.
///
/// Note: [`Throttle`] will only send requests if the returned worker is
/// polled/spawned/awaited.
pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>) {
let (tx, rx) = mpsc::channel(limits.messages_per_sec_overall as usize);
let worker = worker(limits, rx);
let this = Self { bot, queue: tx };
(this, worker)
}
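A minimal usage sketch (assuming an underlying bot value and a tokio runtime; `bot` is hypothetical):

```rust
// Hypothetical: wrap an existing `bot` and spawn the worker so throttled
// requests can make progress.
let (bot, worker) = Throttle::new(bot, Limits::default());
tokio::spawn(worker);
// `bot` can now be used as before; each request waits for the worker's
// permission before being sent.
```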
@@ -450,7 +448,7 @@ where
prices,
),
self.queue.clone(),
|p| ChatIdHash::Id(p.payload_ref().chat_id as _),
)
}
@@ -484,25 +482,25 @@ download_forward! {
{ this => this.inner() }
}
/// An ID used in the worker.
///
/// It is used instead of `ChatId` to make copying cheap even in case of
/// usernames. (It is just a hashed username.)
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
enum ChatIdHash {
Id(i64),
ChannelUsernameHash(u64),
}
impl From<&ChatId> for ChatIdHash {
fn from(value: &ChatId) -> Self {
match value {
ChatId::Id(id) => ChatIdHash::Id(*id),
ChatId::ChannelUsername(username) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
username.hash(&mut hasher);
let hash = hasher.finish();
ChatIdHash::ChannelUsernameHash(hash)
}
}
}
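As a small illustration (hypothetical username), equal usernames always map to the same `ChatIdHash`, which is all the worker's per-chat counting needs:

```rust
// Numeric ids are kept as-is; usernames are reduced to a u64 hash.
let a = ChatIdHash::from(&ChatId::ChannelUsername("@example".to_owned()));
let b = ChatIdHash::from(&ChatId::ChannelUsername("@example".to_owned()));
assert_eq!(a, b); // same username, same hash
assert_eq!(ChatIdHash::from(&ChatId::Id(42)), ChatIdHash::Id(42));
```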
@@ -510,8 +508,8 @@ impl From<&ChatId> for Id {
pub struct ThrottlingRequest<R: HasPayload>(
R,
mpsc::Sender<(ChatIdHash, Sender<Never>)>,
fn(&R::Payload) -> ChatIdHash,
);
impl<R: HasPayload> HasPayload for ThrottlingRequest<R> {
@@ -745,29 +743,30 @@ mod chan_send {
use never::Never;
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
use crate::adaptors::throttle::ChatIdHash;
pub(super) trait SendTy {
fn send_t(self, val: (ChatIdHash, Sender<Never>)) -> ChanSend;
}
#[pin_project::pin_project]
pub(super) struct ChanSend(#[pin] Inner);
#[cfg(not(feature = "nightly"))]
type Inner =
Pin<Box<dyn Future<Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>> + Send>>;
#[cfg(feature = "nightly")]
type Inner = impl Future<Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>>;
impl SendTy for mpsc::Sender<(ChatIdHash, Sender<Never>)> {
// The `return`s trick IDEA into not showing errors.
#[allow(clippy::needless_return)]
fn send_t(self, val: (ChatIdHash, Sender<Never>)) -> ChanSend {
#[cfg(feature = "nightly")]
{
fn def(
sender: mpsc::Sender<(ChatIdHash, Sender<Never>)>,
val: (ChatIdHash, Sender<Never>),
) -> Inner {
async move { sender.send(val).await }
}
@@ -782,7 +781,7 @@
}
impl Future for ChanSend {
type Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().0.poll(cx)