diff --git a/src/adaptors/throttle.rs b/src/adaptors/throttle.rs index 63bbc842..6aca69d5 100644 --- a/src/adaptors/throttle.rs +++ b/src/adaptors/throttle.rs @@ -23,49 +23,48 @@ use crate::{ types::*, }; -// Throttling is quite complicated this comment describes the algorithm of -// current implementation. NOTE: this only describes CURRENT implementation. -// Implementation may change at any time. +// Throttling is quite complicated. This comment describes the algorithm of the +// current implementation. // // ### Request // -// When throttling request is sent, it sends a tuple of `ChatId` (more -// accurately, just local `Id`) and `Sender<()>` to the worker. Then the request -// waits for notification from worker. When notification is received it sends -// underlying request. +// When a throttling request is sent, it sends a tuple of `ChatId` and +// `Sender<()>` to the worker. Then the request waits for a notification from +// the worker. When notification is received, it sends the underlying request. // // ### Worker // -// Worker does the most important job - it checks for limit exceed. +// The worker does the most important job -- it ensures that the limits are +// never exceeded. // -// The worker stores "history" of requests sent in last minute (and to which -// chats the were sent) and queue of pending updates. +// The worker stores a history of requests sent in the last minute (and to which +// chats they were sent) and a queue of pending updates. // // The worker does the following algorithm loop: // -// 1. If queue is empty wait for the first message in incoming channel (and adds -// it to queue). +// 1. If the queue is empty, wait for the first message in incoming channel (and +// add it to the queue). // -// 2. Read all present messages from incoming channel and transfer them to -// queue. +// 2. Read all present messages from an incoming channel and transfer them to +// the queue. // -// 3. Record current time. +// 3. Record the current time. // -// 4. Clear history from records which time < (current - minute) +// 4. Clear the history from records whose time < (current time - minute). // -// 5. Count all requests in which were sent last second, -// `allowed = limit.overall_s - count` +// 5. Count all requests which were sent last second, `allowed = limit.overall_s +// - count`. // -// 6. If `allowed == 0` wait a bit and `continue` to the next iteration +// 6. If `allowed == 0` wait a bit and `continue` to the next iteration. // // 7. Count how many requests were sent to which chats (i.e.: create -// `Map`) (note: the same map, but for last minute also -// exists, but it's updated, instead of recreation) +// `Map`). (Note: the same map, but for last minute also exists, +// but it's updated, instead of recreation.) // -// 8. While `allowed >= 0` search for requests which chat hasn't exceed limits -// (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify -// request that it can be now executed, increase counts, add record to -// history. +// 8. While `allowed >= 0` search for requests which chat haven't exceed the +// limits (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify +// the request that it can be now executed, increase counts, add record to the +// history. const MINUTE: Duration = Duration::from_secs(60); const SECOND: Duration = Duration::from_secs(1); @@ -151,47 +150,34 @@ impl Default for Limits { pub struct Throttle { bot: B, // Sender is used to pass the signal to unlock by closing the channel. - queue: mpsc::Sender<(Id, Sender)>, + queue: mpsc::Sender<(FastChatId, Sender)>, } -async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender)>) { - // +- Same idea as in `Throttle::new` - let capacity = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 4); +type RequestsSent = u32; + +// I wish there was special data structure for history which removed the +// need in 2 hashmaps +// (waffle) +#[derive(Default)] +struct RequestsSentToChats { + per_min: HashMap, + per_sec: HashMap, +} + +async fn worker(limits: Limits, mut rx: mpsc::Receiver<(FastChatId, Sender)>) { // FIXME(waffle): Make an research about data structures for this queue. // Currently this is O(n) removing (n = number of elements // stayed), amortized O(1) push (vec+vecrem). - let mut queue: Vec<(Id, Sender)> = Vec::with_capacity(capacity as usize); + let mut queue: Vec<(FastChatId, Sender)> = + Vec::with_capacity(limits.messages_per_sec_overall as usize); - // I wish there was special data structure for history which removed the - // need in 2 hashmaps - // (waffle) - let mut history: VecDeque<(Id, Instant)> = VecDeque::new(); - let mut hchats: HashMap = HashMap::new(); - let mut hchats_s = HashMap::new(); + let mut when_requests_were_sent: VecDeque<(FastChatId, Instant)> = VecDeque::new(); + let mut requests_sent_to_chats = RequestsSentToChats::default(); - // set to true when `rx` is closed - let mut close = false; + let mut rx_is_closed = false; - while !close || !queue.is_empty() { - // If there are no pending requests we are just waiting - if queue.is_empty() { - match rx.recv().await { - Some(req) => queue.push(req), - None => close = true, - } - } - - // update local queue with latest requests - loop { - // FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350 - match rx.recv().now_or_never() { - Some(Some(req)) => queue.push(req), - // There are no items in queue - None => break, - // The queue was closed - Some(None) => close = true, - } - } + while !rx_is_closed || !queue.is_empty() { + read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await; // _Maybe_ we need to use `spawn_blocking` here, because there is // decent amount of blocking work. However _for now_ I've decided not @@ -231,18 +217,21 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender)>) { let sec_back = now - SECOND; // make history and hchats up-to-date - while let Some((_, time)) = history.front() { + while let Some((_, time)) = when_requests_were_sent.front() { // history is sorted, we found first up-to-date thing if time >= &min_back { break; } - if let Some((chat, _)) = history.pop_front() { - let ent = hchats.entry(chat).and_modify(|count| { - *count -= 1; - }); + if let Some((chat, _)) = when_requests_were_sent.pop_front() { + let entry = requests_sent_to_chats + .per_min + .entry(chat) + .and_modify(|count| { + *count -= 1; + }); - if let Entry::Occupied(entry) = ent { + if let Entry::Occupied(entry) = entry { if *entry.get() == 0 { entry.remove_entry(); } @@ -252,69 +241,90 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender)>) { // as truncates which is ok since in case of truncation it would always be >= // limits.overall_s - let used = history + let used = when_requests_were_sent .iter() .take_while(|(_, time)| time > &sec_back) .count() as u32; let mut allowed = limits.messages_per_sec_overall.saturating_sub(used); if allowed == 0 { - hchats_s.clear(); + requests_sent_to_chats.per_sec.clear(); tokio::time::sleep(DELAY).await; continue; } - for (chat, _) in history.iter().take_while(|(_, time)| time > &sec_back) { - *hchats_s.entry(*chat).or_insert(0) += 1; + for (chat, _) in when_requests_were_sent + .iter() + .take_while(|(_, time)| time > &sec_back) + { + *requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1; } - { - let mut queue_rem = queue.removing(); - while let Some(entry) = queue_rem.next() { - let chat = &entry.value().0; - let cond = { - hchats_s.get(chat).copied().unwrap_or(0) < limits.messages_per_sec_chat - && hchats.get(chat).copied().unwrap_or(0) < limits.messages_per_min_chat - }; + let mut queue_removing_iter = queue.removing(); - if cond { - { - *hchats_s.entry(*chat).or_insert(0) += 1; - *hchats.entry(*chat).or_insert(0) += 1; - history.push_back((*chat, Instant::now())); - } + while let Some(entry) = queue_removing_iter.next() { + let chat = &entry.value().0; + let messages_sent = requests_sent_to_chats + .per_sec + .get(chat) + .copied() + .unwrap_or(0); + let limits_not_exceeded = messages_sent < limits.messages_per_sec_chat + && messages_sent < limits.messages_per_min_chat; - // This will close the channel unlocking associated request - drop(entry.remove()); + if limits_not_exceeded { + *requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1; + *requests_sent_to_chats.per_min.entry(*chat).or_insert(0) += 1; + when_requests_were_sent.push_back((*chat, Instant::now())); - // We've "sent" 1 request, so now we can send 1 less - allowed -= 1; - if allowed == 0 { - break; - } + // Close the channel and unlock the associated request. + drop(entry.remove()); + + // We have "sent" one request, so now we can send one less. + allowed -= 1; + if allowed == 0 { + break; } } } // It's easier to just recompute last second stats, instead of keeping // track of it alongside with minute stats, so we just throw this away. - hchats_s.clear(); + requests_sent_to_chats.per_sec.clear(); tokio::time::sleep(DELAY).await; } } +async fn read_from_rx( + rx: &mut mpsc::Receiver<(FastChatId, Sender)>, + queue: &mut Vec<(FastChatId, Sender)>, + rx_is_closed: &mut bool, +) { + if queue.is_empty() { + match rx.recv().await { + Some(req) => queue.push(req), + None => *rx_is_closed = true, + } + } + + loop { + // FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350 + match rx.recv().now_or_never() { + Some(Some(req)) => queue.push(req), + Some(None) => *rx_is_closed = true, + // There are no items in queue. + None => break, + } + } +} + impl Throttle { /// Creates new [`Throttle`] alongside with worker future. /// /// Note: [`Throttle`] will only send requests if returned worker is /// polled/spawned/awaited. pub fn new(bot: B, limits: Limits) -> (Self, impl Future) { - // A buffer made slightly bigger (112.5%) than overall limit - // so we won't lose performance when hitting limits. - // - // (I hope this makes sense) (waffle) - let buffer_size = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 8); - let (tx, rx) = mpsc::channel(buffer_size as usize); + let (tx, rx) = mpsc::channel(limits.messages_per_sec_overall as usize); let worker = worker(limits, rx); let this = Self { bot, queue: tx }; @@ -448,7 +458,7 @@ where prices, ), self.queue.clone(), - |p| Id::Id(p.payload_ref().chat_id as _), + |p| FastChatId::Id(p.payload_ref().chat_id as _), ) } @@ -482,25 +492,25 @@ download_forward! { { this => this.inner() } } -/// Id used in worker. +/// An ID used in the worker. /// /// It is used instead of `ChatId` to make copying cheap even in case of -/// usernames. (It just hashes username) +/// usernames. (It is just a hashed username.) #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] -enum Id { +enum FastChatId { Id(i64), ChannelUsernameHash(u64), } -impl From<&ChatId> for Id { +impl From<&ChatId> for FastChatId { fn from(value: &ChatId) -> Self { match value { - ChatId::Id(id) => Id::Id(*id), + ChatId::Id(id) => FastChatId::Id(*id), ChatId::ChannelUsername(username) => { let mut hasher = std::collections::hash_map::DefaultHasher::new(); username.hash(&mut hasher); let hash = hasher.finish(); - Id::ChannelUsernameHash(hash) + FastChatId::ChannelUsernameHash(hash) } } } @@ -508,8 +518,8 @@ impl From<&ChatId> for Id { pub struct ThrottlingRequest( R, - mpsc::Sender<(Id, Sender)>, - fn(&R::Payload) -> Id, + mpsc::Sender<(FastChatId, Sender)>, + fn(&R::Payload) -> FastChatId, ); impl HasPayload for ThrottlingRequest { @@ -743,29 +753,30 @@ mod chan_send { use never::Never; use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender}; - use crate::adaptors::throttle::Id; + use crate::adaptors::throttle::FastChatId; pub(super) trait SendTy { - fn send_t(self, val: (Id, Sender)) -> ChanSend; + fn send_t(self, val: (FastChatId, Sender)) -> ChanSend; } #[pin_project::pin_project] pub(super) struct ChanSend(#[pin] Inner); #[cfg(not(feature = "nightly"))] - type Inner = Pin)>>> + Send>>; + type Inner = + Pin)>>> + Send>>; #[cfg(feature = "nightly")] - type Inner = impl Future)>>>; + type Inner = impl Future)>>>; - impl SendTy for mpsc::Sender<(Id, Sender)> { + impl SendTy for mpsc::Sender<(FastChatId, Sender)> { // `return`s trick IDEA not to show errors #[allow(clippy::needless_return)] - fn send_t(self, val: (Id, Sender)) -> ChanSend { + fn send_t(self, val: (FastChatId, Sender)) -> ChanSend { #[cfg(feature = "nightly")] { fn def( - sender: mpsc::Sender<(Id, Sender)>, - val: (Id, Sender), + sender: mpsc::Sender<(FastChatId, Sender)>, + val: (FastChatId, Sender), ) -> Inner { async move { sender.send(val).await } } @@ -780,7 +791,7 @@ mod chan_send { } impl Future for ChanSend { - type Output = Result<(), SendError<(Id, Sender)>>; + type Output = Result<(), SendError<(FastChatId, Sender)>>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().0.poll(cx)