Refactoring

This commit is contained in:
Temirkhan Myrzamadi 2021-02-16 00:50:46 +06:00
parent d8c16d420d
commit d58b21c42f

View file

@ -23,48 +23,47 @@ use crate::{
types::*,
};
// Throttling is quite complicated this comment describes the algorithm of
// current implementation. NOTE: this only describes CURRENT implementation.
// Implementation may change at any time.
// Throttling is quite complicated. This comment describes the algorithm of the
// current implementation.
//
// ### Request
//
// When throttling request is sent, it sends a tuple of `ChatId` (more
// accurately, just local `Id`) and `Sender<()>` to the worker. Then the request
// waits for notification from worker. When notification is received it sends
// underlying request.
// When a throttling request is sent, it sends a tuple of `ChatId` and
// `Sender<()>` to the worker. Then the request waits for a notification from
// the worker. When notification is received, it sends the underlying request.
//
// ### Worker
//
// Worker does the most important job - it checks for limit exceed.
// The worker does the most important job -- it ensures that the limits are
// never exceeded.
//
// The worker stores "history" of requests sent in last minute (and to which
// chats the were sent) and queue of pending updates.
// The worker stores a history of requests sent in the last minute (and to which
// chats they were sent) and a queue of pending updates.
//
// The worker does the following algorithm loop:
//
// 1. If queue is empty wait for the first message in incoming channel (and adds
// it to queue).
// 1. If the queue is empty, wait for the first message in incoming channel (and
// add it to the queue).
//
// 2. Read all present messages from incoming channel and transfer them to
// queue.
// 2. Read all present messages from an incoming channel and transfer them to
// the queue.
//
// 3. Record current time.
// 3. Record the current time.
//
// 4. Clear history from records which time < (current - minute)
// 4. Clear the history from records whose time < (current time - minute).
//
// 5. Count all requests in which were sent last second,
// `allowed = limit.overall_s - count`
// 5. Count all requests which were sent last second, `allowed = limit.overall_s
// - count`.
//
// 6. If `allowed == 0` wait a bit and `continue` to the next iteration
// 6. If `allowed == 0` wait a bit and `continue` to the next iteration.
//
// 7. Count how many requests were sent to which chats (i.e.: create
// `Map<ChatId, Count>`) (note: the same map, but for last minute also
// exists, but it's updated, instead of recreation)
// `Map<ChatId, Count>`). (Note: the same map, but for last minute also exists,
// but it's updated, instead of recreation.)
//
// 8. While `allowed >= 0` search for requests which chat hasn't exceed limits
// (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify
// request that it can be now executed, increase counts, add record to
// 8. While `allowed >= 0` search for requests which chat haven't exceed the
// limits (i.e.: map[chat] < limit), if one is found, decrease `allowed`, notify
// the request that it can be now executed, increase counts, add record to the
// history.
const MINUTE: Duration = Duration::from_secs(60);
@ -151,47 +150,34 @@ impl Default for Limits {
pub struct Throttle<B> {
bot: B,
// Sender<Never> is used to pass the signal to unlock by closing the channel.
queue: mpsc::Sender<(Id, Sender<Never>)>,
queue: mpsc::Sender<(FastChatId, Sender<Never>)>,
}
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) {
// +- Same idea as in `Throttle::new`
let capacity = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 4);
type RequestsSent = u32;
// I wish there was special data structure for history which removed the
// need in 2 hashmaps
// (waffle)
#[derive(Default)]
struct RequestsSentToChats {
per_min: HashMap<FastChatId, RequestsSent>,
per_sec: HashMap<FastChatId, RequestsSent>,
}
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(FastChatId, Sender<Never>)>) {
// FIXME(waffle): Make an research about data structures for this queue.
// Currently this is O(n) removing (n = number of elements
// stayed), amortized O(1) push (vec+vecrem).
let mut queue: Vec<(Id, Sender<Never>)> = Vec::with_capacity(capacity as usize);
let mut queue: Vec<(FastChatId, Sender<Never>)> =
Vec::with_capacity(limits.messages_per_sec_overall as usize);
// I wish there was special data structure for history which removed the
// need in 2 hashmaps
// (waffle)
let mut history: VecDeque<(Id, Instant)> = VecDeque::new();
let mut hchats: HashMap<Id, u32> = HashMap::new();
let mut hchats_s = HashMap::new();
let mut when_requests_were_sent: VecDeque<(FastChatId, Instant)> = VecDeque::new();
let mut requests_sent_to_chats = RequestsSentToChats::default();
// set to true when `rx` is closed
let mut close = false;
let mut rx_is_closed = false;
while !close || !queue.is_empty() {
// If there are no pending requests we are just waiting
if queue.is_empty() {
match rx.recv().await {
Some(req) => queue.push(req),
None => close = true,
}
}
// update local queue with latest requests
loop {
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
match rx.recv().now_or_never() {
Some(Some(req)) => queue.push(req),
// There are no items in queue
None => break,
// The queue was closed
Some(None) => close = true,
}
}
while !rx_is_closed || !queue.is_empty() {
read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await;
// _Maybe_ we need to use `spawn_blocking` here, because there is
// decent amount of blocking work. However _for now_ I've decided not
@ -231,18 +217,21 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) {
let sec_back = now - SECOND;
// make history and hchats up-to-date
while let Some((_, time)) = history.front() {
while let Some((_, time)) = when_requests_were_sent.front() {
// history is sorted, we found first up-to-date thing
if time >= &min_back {
break;
}
if let Some((chat, _)) = history.pop_front() {
let ent = hchats.entry(chat).and_modify(|count| {
if let Some((chat, _)) = when_requests_were_sent.pop_front() {
let entry = requests_sent_to_chats
.per_min
.entry(chat)
.and_modify(|count| {
*count -= 1;
});
if let Entry::Occupied(entry) = ent {
if let Entry::Occupied(entry) = entry {
if *entry.get() == 0 {
entry.remove_entry();
}
@ -252,69 +241,90 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(Id, Sender<Never>)>) {
// as truncates which is ok since in case of truncation it would always be >=
// limits.overall_s
let used = history
let used = when_requests_were_sent
.iter()
.take_while(|(_, time)| time > &sec_back)
.count() as u32;
let mut allowed = limits.messages_per_sec_overall.saturating_sub(used);
if allowed == 0 {
hchats_s.clear();
requests_sent_to_chats.per_sec.clear();
tokio::time::sleep(DELAY).await;
continue;
}
for (chat, _) in history.iter().take_while(|(_, time)| time > &sec_back) {
*hchats_s.entry(*chat).or_insert(0) += 1;
for (chat, _) in when_requests_were_sent
.iter()
.take_while(|(_, time)| time > &sec_back)
{
*requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1;
}
{
let mut queue_rem = queue.removing();
while let Some(entry) = queue_rem.next() {
let mut queue_removing_iter = queue.removing();
while let Some(entry) = queue_removing_iter.next() {
let chat = &entry.value().0;
let cond = {
hchats_s.get(chat).copied().unwrap_or(0) < limits.messages_per_sec_chat
&& hchats.get(chat).copied().unwrap_or(0) < limits.messages_per_min_chat
};
let messages_sent = requests_sent_to_chats
.per_sec
.get(chat)
.copied()
.unwrap_or(0);
let limits_not_exceeded = messages_sent < limits.messages_per_sec_chat
&& messages_sent < limits.messages_per_min_chat;
if cond {
{
*hchats_s.entry(*chat).or_insert(0) += 1;
*hchats.entry(*chat).or_insert(0) += 1;
history.push_back((*chat, Instant::now()));
}
if limits_not_exceeded {
*requests_sent_to_chats.per_sec.entry(*chat).or_insert(0) += 1;
*requests_sent_to_chats.per_min.entry(*chat).or_insert(0) += 1;
when_requests_were_sent.push_back((*chat, Instant::now()));
// This will close the channel unlocking associated request
// Close the channel and unlock the associated request.
drop(entry.remove());
// We've "sent" 1 request, so now we can send 1 less
// We have "sent" one request, so now we can send one less.
allowed -= 1;
if allowed == 0 {
break;
}
}
}
}
// It's easier to just recompute last second stats, instead of keeping
// track of it alongside with minute stats, so we just throw this away.
hchats_s.clear();
requests_sent_to_chats.per_sec.clear();
tokio::time::sleep(DELAY).await;
}
}
async fn read_from_rx(
rx: &mut mpsc::Receiver<(FastChatId, Sender<Never>)>,
queue: &mut Vec<(FastChatId, Sender<Never>)>,
rx_is_closed: &mut bool,
) {
if queue.is_empty() {
match rx.recv().await {
Some(req) => queue.push(req),
None => *rx_is_closed = true,
}
}
loop {
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
match rx.recv().now_or_never() {
Some(Some(req)) => queue.push(req),
Some(None) => *rx_is_closed = true,
// There are no items in queue.
None => break,
}
}
}
impl<B> Throttle<B> {
/// Creates new [`Throttle`] alongside with worker future.
///
/// Note: [`Throttle`] will only send requests if returned worker is
/// polled/spawned/awaited.
pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>) {
// A buffer made slightly bigger (112.5%) than overall limit
// so we won't lose performance when hitting limits.
//
// (I hope this makes sense) (waffle)
let buffer_size = limits.messages_per_sec_overall + (limits.messages_per_sec_overall / 8);
let (tx, rx) = mpsc::channel(buffer_size as usize);
let (tx, rx) = mpsc::channel(limits.messages_per_sec_overall as usize);
let worker = worker(limits, rx);
let this = Self { bot, queue: tx };
@ -448,7 +458,7 @@ where
prices,
),
self.queue.clone(),
|p| Id::Id(p.payload_ref().chat_id as _),
|p| FastChatId::Id(p.payload_ref().chat_id as _),
)
}
@ -482,25 +492,25 @@ download_forward! {
{ this => this.inner() }
}
/// Id used in worker.
/// An ID used in the worker.
///
/// It is used instead of `ChatId` to make copying cheap even in case of
/// usernames. (It just hashes username)
/// usernames. (It is just a hashed username.)
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
enum Id {
enum FastChatId {
Id(i64),
ChannelUsernameHash(u64),
}
impl From<&ChatId> for Id {
impl From<&ChatId> for FastChatId {
fn from(value: &ChatId) -> Self {
match value {
ChatId::Id(id) => Id::Id(*id),
ChatId::Id(id) => FastChatId::Id(*id),
ChatId::ChannelUsername(username) => {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
username.hash(&mut hasher);
let hash = hasher.finish();
Id::ChannelUsernameHash(hash)
FastChatId::ChannelUsernameHash(hash)
}
}
}
@ -508,8 +518,8 @@ impl From<&ChatId> for Id {
pub struct ThrottlingRequest<R: HasPayload>(
R,
mpsc::Sender<(Id, Sender<Never>)>,
fn(&R::Payload) -> Id,
mpsc::Sender<(FastChatId, Sender<Never>)>,
fn(&R::Payload) -> FastChatId,
);
impl<R: HasPayload> HasPayload for ThrottlingRequest<R> {
@ -743,29 +753,30 @@ mod chan_send {
use never::Never;
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
use crate::adaptors::throttle::Id;
use crate::adaptors::throttle::FastChatId;
pub(super) trait SendTy {
fn send_t(self, val: (Id, Sender<Never>)) -> ChanSend;
fn send_t(self, val: (FastChatId, Sender<Never>)) -> ChanSend;
}
#[pin_project::pin_project]
pub(super) struct ChanSend(#[pin] Inner);
#[cfg(not(feature = "nightly"))]
type Inner = Pin<Box<dyn Future<Output = Result<(), SendError<(Id, Sender<Never>)>>> + Send>>;
type Inner =
Pin<Box<dyn Future<Output = Result<(), SendError<(FastChatId, Sender<Never>)>>> + Send>>;
#[cfg(feature = "nightly")]
type Inner = impl Future<Output = Result<(), SendError<(Id, Sender<Never>)>>>;
type Inner = impl Future<Output = Result<(), SendError<(FastChatId, Sender<Never>)>>>;
impl SendTy for mpsc::Sender<(Id, Sender<Never>)> {
impl SendTy for mpsc::Sender<(FastChatId, Sender<Never>)> {
// `return`s trick IDEA not to show errors
#[allow(clippy::needless_return)]
fn send_t(self, val: (Id, Sender<Never>)) -> ChanSend {
fn send_t(self, val: (FastChatId, Sender<Never>)) -> ChanSend {
#[cfg(feature = "nightly")]
{
fn def(
sender: mpsc::Sender<(Id, Sender<Never>)>,
val: (Id, Sender<Never>),
sender: mpsc::Sender<(FastChatId, Sender<Never>)>,
val: (FastChatId, Sender<Never>),
) -> Inner {
async move { sender.send(val).await }
}
@ -780,7 +791,7 @@ mod chan_send {
}
impl Future for ChanSend {
type Output = Result<(), SendError<(Id, Sender<Never>)>>;
type Output = Result<(), SendError<(FastChatId, Sender<Never>)>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().0.poll(cx)