Refactoring

This commit is contained in:
Temirkhan Myrzamadi 2021-02-16 00:50:46 +06:00
parent d8c16d420d
commit d58b21c42f

View file

@ -23,48 +23,47 @@ use crate::{
types::*, types::*,
}; };
// Throttling is quite complicated this comment describes the algorithm of // Throttling is quite complicated. This comment describes the algorithm of the
// current implementation. NOTE: this only describes CURRENT implementation. // current implementation.
// Implementation may change at any time.
// //
// ### Request // ### Request
// //
// When throttling request is sent, it sends a tuple of `ChatId` (more // When a throttling request is sent, it sends a tuple of `ChatId` and
// accurately, just local `Id`) and `Sender<()>` to the worker. Then the request // `Sender<()>` to the worker. Then the request waits for a notification from
// waits for notification from worker. When notification is received it sends // the worker. When notification is received, it sends the underlying request.
// underlying request.
// //
// ### Worker // ### Worker
// //
// Worker does the most important job - it checks for limit exceed. // The worker does the most important job -- it ensures that the limits are
// never exceeded.
// //
// The worker stores "history" of requests sent in last minute (and to which // The worker stores a history of requests sent in the last minute (and to which
// chats the were sent) and queue of pending updates. // chats they were sent) and a queue of pending updates.
// //
// The worker does the following algorithm loop: // The worker does the following algorithm loop:
// //
// 1. If queue is empty wait for the first message in incoming channel (and adds // 1. If the queue is empty, wait for the first message in incoming channel (and
// it to queue). // add it to the queue).
// //
// 2. Read all present messages from incoming channel and transfer them to // 2. Read all present messages from an incoming channel and transfer them to
// queue. // the queue.
// //
// 3. Record current time. // 3. Record the current time.
// //
// 4. Clear history from records which time < (current - minute) // 4. Clear the history from records whose time < (current time - minute).
// //
// 5. Count all requests in which were sent last second, // 5. Count all requests which were sent last second, `allowed = limit.overall_s
// `allowed = limit.overall_s - count` // - count`.
// //
// 6. If `allowed == 0` wait a bit and `continue` to the next iteration // 6. If `allowed == 0` wait a bit and `continue` to the next iteration.
// //
// 7. Count how many requests were sent to which chats (i.e.: create // 7. Count how many requests were sent to which chats (i.e.: create
// `Map<ChatId, Count>`) (note: the same map, but for last minute also // `Map<ChatId, Count>`). (Note: the same map, but for last minute also exists,
// exists, but it's updated, instead of recreation) // but it's updated, instead of recreation.)
// //
// 8. While `allowed >= 0` search for requests which chat hasn't exceed limits // 8. While `allowed >= 0` search for requests which chat haven't exceed the
// (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify // limits (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify
// request that it can be now executed, increase counts, add record to // the request that it can be now executed, increase counts, add record to the
// history. // history.
const MINUTE: Duration = Duration::from_secs(60); const MINUTE: Duration = Duration::from_secs(60);
@ -151,47 +150,34 @@ impl Default for Limits {
pub struct Throttle<B> { pub struct Throttle<B> {
bot: B, bot: B,
// Sender<Never> is used to pass the signal to unlock by closing the channel. // Sender<Never> is used to pass the signal to unlock by closing the channel.
queue: mpsc::Sender<(Id, Sender<Never>)>, queue: mpsc::Sender<(FastChatId, Sender<Never>)>,
} }
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) { type RequestsSent = u32;
// +- Same idea as in `Throttle::new`
let capacity = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 4);
// FIXME(waffle): Make an research about data structures for this queue.
// Currently this is O(n) removing (n = number of elements
// stayed), amortized O(1) push (vec+vecrem).
let mut queue: Vec<(Id, Sender<Never>)> = Vec::with_capacity(capacity as usize);
// I wish there was special data structure for history which removed the // I wish there was special data structure for history which removed the
// need in 2 hashmaps // need in 2 hashmaps
// (waffle) // (waffle)
let mut history: VecDeque<(Id, Instant)> = VecDeque::new(); #[derive(Default)]
let mut hchats: HashMap<Id, u32> = HashMap::new(); struct RequestsSentToChats {
let mut hchats_s = HashMap::new(); per_min: HashMap<FastChatId, RequestsSent>,
per_sec: HashMap<FastChatId, RequestsSent>,
// set to true when `rx` is closed
let mut close = false;
while !close || !queue.is_empty() {
// If there are no pending requests we are just waiting
if queue.is_empty() {
match rx.recv().await {
Some(req) => queue.push(req),
None => close = true,
}
} }
// update local queue with latest requests async fn worker(limits: Limits, mut rx: mpsc::Receiver<(FastChatId, Sender<Never>)>) {
loop { // FIXME(waffle): Make an research about data structures for this queue.
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350 // Currently this is O(n) removing (n = number of elements
match rx.recv().now_or_never() { // stayed), amortized O(1) push (vec+vecrem).
Some(Some(req)) => queue.push(req), let mut queue: Vec<(FastChatId, Sender<Never>)> =
// There are no items in queue Vec::with_capacity(limits.messages_per_sec_overall as usize);
None => break,
// The queue was closed let mut when_requests_were_sent: VecDeque<(FastChatId, Instant)> = VecDeque::new();
Some(None) => close = true, let mut requests_sent_to_chats = RequestsSentToChats::default();
}
} let mut rx_is_closed = false;
while !rx_is_closed || !queue.is_empty() {
read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await;
// _Maybe_ we need to use `spawn_blocking` here, because there is // _Maybe_ we need to use `spawn_blocking` here, because there is
// decent amount of blocking work. However _for now_ I've decided not // decent amount of blocking work. However _for now_ I've decided not
@ -231,18 +217,21 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) {
let sec_back = now - SECOND; let sec_back = now - SECOND;
// make history and hchats up-to-date // make history and hchats up-to-date
while let Some((_, time)) = history.front() { while let Some((_, time)) = when_requests_were_sent.front() {
// history is sorted, we found first up-to-date thing // history is sorted, we found first up-to-date thing
if time >= &min_back { if time >= &min_back {
break; break;
} }
if let Some((chat, _)) = history.pop_front() { if let Some((chat, _)) = when_requests_were_sent.pop_front() {
let ent = hchats.entry(chat).and_modify(|count| { let entry = requests_sent_to_chats
.per_min
.entry(chat)
.and_modify(|count| {
*count -= 1; *count -= 1;
}); });
if let Entry::Occupied(entry) = ent { if let Entry::Occupied(entry) = entry {
if *entry.get() == 0 { if *entry.get() == 0 {
entry.remove_entry(); entry.remove_entry();
} }
@ -252,69 +241,90 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) {
// as truncates which is ok since in case of truncation it would always be >= // as truncates which is ok since in case of truncation it would always be >=
// limits.overall_s // limits.overall_s
let used = history let used = when_requests_were_sent
.iter() .iter()
.take_while(|(_, time)| time > &sec_back) .take_while(|(_, time)| time > &sec_back)
.count() as u32; .count() as u32;
let mut allowed = limits.messages_per_sec_overall.saturating_sub(used); let mut allowed = limits.messages_per_sec_overall.saturating_sub(used);
if allowed == 0 { if allowed == 0 {
hchats_s.clear(); requests_sent_to_chats.per_sec.clear();
tokio::time::sleep(DELAY).await; tokio::time::sleep(DELAY).await;
continue; continue;
} }
for (chat, _) in history.iter().take_while(|(_, time)| time > &sec_back) { for (chat, _) in when_requests_were_sent
*hchats_s.entry(*chat).or_insert(0) += 1; .iter()
.take_while(|(_, time)| time > &sec_back)
{
*requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1;
} }
{ let mut queue_removing_iter = queue.removing();
let mut queue_rem = queue.removing();
while let Some(entry) = queue_rem.next() { while let Some(entry) = queue_removing_iter.next() {
let chat = &entry.value().0; let chat = &entry.value().0;
let cond = { let messages_sent = requests_sent_to_chats
hchats_s.get(chat).copied().unwrap_or(0) < limits.messages_per_sec_chat .per_sec
&& hchats.get(chat).copied().unwrap_or(0) < limits.messages_per_min_chat .get(chat)
}; .copied()
.unwrap_or(0);
let limits_not_exceeded = messages_sent < limits.messages_per_sec_chat
&& messages_sent < limits.messages_per_min_chat;
if cond { if limits_not_exceeded {
{ *requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1;
*hchats_s.entry(*chat).or_insert(0) += 1; *requests_sent_to_chats.per_min.entry(*chat).or_insert(0) += 1;
*hchats.entry(*chat).or_insert(0) += 1; when_requests_were_sent.push_back((*chat, Instant::now()));
history.push_back((*chat, Instant::now()));
}
// This will close the channel unlocking associated request // Close the channel and unlock the associated request.
drop(entry.remove()); drop(entry.remove());
// We've "sent" 1 request, so now we can send 1 less // We have "sent" one request, so now we can send one less.
allowed -= 1; allowed -= 1;
if allowed == 0 { if allowed == 0 {
break; break;
} }
} }
} }
}
// It's easier to just recompute last second stats, instead of keeping // It's easier to just recompute last second stats, instead of keeping
// track of it alongside with minute stats, so we just throw this away. // track of it alongside with minute stats, so we just throw this away.
hchats_s.clear(); requests_sent_to_chats.per_sec.clear();
tokio::time::sleep(DELAY).await; tokio::time::sleep(DELAY).await;
} }
} }
async fn read_from_rx(
rx: &mut mpsc::Receiver<(FastChatId, Sender<Never>)>,
queue: &mut Vec<(FastChatId, Sender<Never>)>,
rx_is_closed: &mut bool,
) {
if queue.is_empty() {
match rx.recv().await {
Some(req) => queue.push(req),
None => *rx_is_closed = true,
}
}
loop {
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
match rx.recv().now_or_never() {
Some(Some(req)) => queue.push(req),
Some(None) => *rx_is_closed = true,
// There are no items in queue.
None => break,
}
}
}
impl<B> Throttle<B> { impl<B> Throttle<B> {
/// Creates new [`Throttle`] alongside with worker future. /// Creates new [`Throttle`] alongside with worker future.
/// ///
/// Note: [`Throttle`] will only send requests if returned worker is /// Note: [`Throttle`] will only send requests if returned worker is
/// polled/spawned/awaited. /// polled/spawned/awaited.
pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>) { pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>) {
// A buffer made slightly bigger (112.5%) than overall limit let (tx, rx) = mpsc::channel(limits.messages_per_sec_overall as usize);
// so we won't lose performance when hitting limits.
//
// (I hope this makes sense) (waffle)
let buffer_size = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 8);
let (tx, rx) = mpsc::channel(buffer_size as usize);
let worker = worker(limits, rx); let worker = worker(limits, rx);
let this = Self { bot, queue: tx }; let this = Self { bot, queue: tx };
@ -448,7 +458,7 @@ where
prices, prices,
), ),
self.queue.clone(), self.queue.clone(),
|p| Id::Id(p.payload_ref().chat_id as _), |p| FastChatId::Id(p.payload_ref().chat_id as _),
) )
} }
@ -482,25 +492,25 @@ download_forward! {
{ this => this.inner() } { this => this.inner() }
} }
/// Id used in worker. /// An ID used in the worker.
/// ///
/// It is used instead of `ChatId` to make copying cheap even in case of /// It is used instead of `ChatId` to make copying cheap even in case of
/// usernames. (It just hashes username) /// usernames. (It is just a hashed username.)
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
enum Id { enum FastChatId {
Id(i64), Id(i64),
ChannelUsernameHash(u64), ChannelUsernameHash(u64),
} }
impl From<&ChatId> for Id { impl From<&ChatId> for FastChatId {
fn from(value: &ChatId) -> Self { fn from(value: &ChatId) -> Self {
match value { match value {
ChatId::Id(id) => Id::Id(*id), ChatId::Id(id) => FastChatId::Id(*id),
ChatId::ChannelUsername(username) => { ChatId::ChannelUsername(username) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new(); let mut hasher = std::collections::hash_map::DefaultHasher::new();
username.hash(&mut hasher); username.hash(&mut hasher);
let hash = hasher.finish(); let hash = hasher.finish();
Id::ChannelUsernameHash(hash) FastChatId::ChannelUsernameHash(hash)
} }
} }
} }
@ -508,8 +518,8 @@ impl From<&ChatId> for Id {
pub struct ThrottlingRequest<R: HasPayload>( pub struct ThrottlingRequest<R: HasPayload>(
R, R,
mpsc::Sender<(Id, Sender<Never>)>, mpsc::Sender<(FastChatId, Sender<Never>)>,
fn(&R::Payload) -> Id, fn(&R::Payload) -> FastChatId,
); );
impl<R: HasPayload> HasPayload for ThrottlingRequest<R> { impl<R: HasPayload> HasPayload for ThrottlingRequest<R> {
@ -743,29 +753,30 @@ mod chan_send {
use never::Never; use never::Never;
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender}; use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
use crate::adaptors::throttle::Id; use crate::adaptors::throttle::FastChatId;
pub(super) trait SendTy { pub(super) trait SendTy {
fn send_t(self, val: (Id, Sender<Never>)) -> ChanSend; fn send_t(self, val: (FastChatId, Sender<Never>)) -> ChanSend;
} }
#[pin_project::pin_project] #[pin_project::pin_project]
pub(super) struct ChanSend(#[pin] Inner); pub(super) struct ChanSend(#[pin] Inner);
#[cfg(not(feature = "nightly"))] #[cfg(not(feature = "nightly"))]
type Inner = Pin<Box<dyn Future<Output = Result<(), SendError<(Id, Sender<Never>)>>> + Send>>; type Inner =
Pin<Box<dyn Future<Output = Result<(), SendError<(FastChatId, Sender<Never>)>>> + Send>>;
#[cfg(feature = "nightly")] #[cfg(feature = "nightly")]
type Inner = impl Future<Output = Result<(), SendError<(Id, Sender<Never>)>>>; type Inner = impl Future<Output = Result<(), SendError<(FastChatId, Sender<Never>)>>>;
impl SendTy for mpsc::Sender<(Id, Sender<Never>)> { impl SendTy for mpsc::Sender<(FastChatId, Sender<Never>)> {
// `return`s trick IDEA not to show errors // `return`s trick IDEA not to show errors
#[allow(clippy::needless_return)] #[allow(clippy::needless_return)]
fn send_t(self, val: (Id, Sender<Never>)) -> ChanSend { fn send_t(self, val: (FastChatId, Sender<Never>)) -> ChanSend {
#[cfg(feature = "nightly")] #[cfg(feature = "nightly")]
{ {
fn def( fn def(
sender: mpsc::Sender<(Id, Sender<Never>)>, sender: mpsc::Sender<(FastChatId, Sender<Never>)>,
val: (Id, Sender<Never>), val: (FastChatId, Sender<Never>),
) -> Inner { ) -> Inner {
async move { sender.send(val).await } async move { sender.send(val).await }
} }
@ -780,7 +791,7 @@ mod chan_send {
} }
impl Future for ChanSend { impl Future for ChanSend {
type Output = Result<(), SendError<(Id, Sender<Never>)>>; type Output = Result<(), SendError<(FastChatId, Sender<Never>)>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().0.poll(cx) self.project().0.poll(cx)