Improve Throttling adoptor

- Freeze when getting `RetryAfter(_)` error
- Retry requests that previously returned `RetryAfter(_)` error
This commit is contained in:
Maybe Waffle 2021-10-24 15:46:50 +03:00
parent 6a91c44836
commit 43802a5c41
3 changed files with 424 additions and 193 deletions

View file

@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed ### Changed
- Improve `Throttling` adoptor
- Freeze when getting `RetryAfter(_)` error
- Retry requests that previously returned `RetryAfter(_)` error
- `RequestError::RetryAfter` now has a `u32` field instead of `i32` - `RequestError::RetryAfter` now has a `u32` field instead of `i32`
### Added ### Added

View file

@ -11,7 +11,6 @@ use futures::{
task::{Context, Poll}, task::{Context, Poll},
FutureExt, FutureExt,
}; };
use never::Never;
use tokio::sync::{ use tokio::sync::{
mpsc, mpsc,
oneshot::{self, Receiver, Sender}, oneshot::{self, Receiver, Sender},
@ -21,6 +20,7 @@ use vecrem::VecExt;
use crate::{ use crate::{
adaptors::throttle::chan_send::{ChanSend, MpscSend}, adaptors::throttle::chan_send::{ChanSend, MpscSend},
errors::AsResponseParameters,
requests::{HasPayload, Output, Request, Requester}, requests::{HasPayload, Output, Request, Requester},
types::*, types::*,
}; };
@ -76,7 +76,11 @@ impl<B> Throttle<B> {
/// ///
/// Note: [`Throttle`] will only send requests if returned worker is /// Note: [`Throttle`] will only send requests if returned worker is
/// polled/spawned/awaited. /// polled/spawned/awaited.
pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>) { pub fn new(bot: B, limits: Limits) -> (Self, impl Future<Output = ()>)
where
B: Requester + Clone,
B::Err: AsResponseParameters,
{
let settings = Settings { let settings = Settings {
limits, limits,
..<_>::default() ..<_>::default()
@ -88,11 +92,15 @@ impl<B> Throttle<B> {
/// ///
/// Note: [`Throttle`] will only send requests if returned worker is /// Note: [`Throttle`] will only send requests if returned worker is
/// polled/spawned/awaited. /// polled/spawned/awaited.
pub fn with_settings(bot: B, settings: Settings) -> (Self, impl Future<Output = ()>) { pub fn with_settings(bot: B, settings: Settings) -> (Self, impl Future<Output = ()>)
where
B: Requester + Clone,
B::Err: AsResponseParameters,
{
let (tx, rx) = mpsc::channel(settings.limits.messages_per_sec_overall as usize); let (tx, rx) = mpsc::channel(settings.limits.messages_per_sec_overall as usize);
let (info_tx, info_rx) = mpsc::channel(2); let (info_tx, info_rx) = mpsc::channel(2);
let worker = worker(settings, rx, info_rx); let worker = worker(settings, rx, info_rx, bot.clone());
let this = Self { let this = Self {
bot, bot,
queue: tx, queue: tx,
@ -107,44 +115,27 @@ impl<B> Throttle<B> {
/// Note: it's recommended to use [`RequesterExt::throttle`] instead. /// Note: it's recommended to use [`RequesterExt::throttle`] instead.
/// ///
/// [`RequesterExt::throttle`]: crate::requests::RequesterExt::throttle /// [`RequesterExt::throttle`]: crate::requests::RequesterExt::throttle
pub fn new_spawn(bot: B, limits: Limits) -> Self { pub fn new_spawn(bot: B, limits: Limits) -> Self
// new/with_settings copy-pasted here to avoid [rust-lang/#76882] where
// B: Requester + Clone + Send + Sync + 'static,
// [rust-lang/#76882]: https://github.com/rust-lang/rust/issues/76882 B::Err: AsResponseParameters,
B::GetChat: Send,
let (tx, rx) = mpsc::channel(limits.messages_per_sec_overall as usize); {
let (info_tx, info_rx) = mpsc::channel(2); let (this, worker) = Self::new(bot, limits);
let settings = Settings {
limits,
..<_>::default()
};
let worker = worker(settings, rx, info_rx);
let this = Self {
bot,
queue: tx,
info_tx,
};
tokio::spawn(worker); tokio::spawn(worker);
this this
} }
/// Creates new [`Throttle`] spawning the worker with `tokio::spawn` /// Creates new [`Throttle`] spawning the worker with `tokio::spawn`
pub fn spawn_with_settings(bot: B, settings: Settings) -> Self { pub fn spawn_with_settings(bot: B, settings: Settings) -> Self
// with_settings copy-pasted here to avoid [rust-lang/#76882] where
// B: Requester + Clone + Send + Sync + 'static,
// [rust-lang/#76882]: https://github.com/rust-lang/rust/issues/76882 B::Err: AsResponseParameters,
B::GetChat: Send,
let (tx, rx) = mpsc::channel(settings.limits.messages_per_sec_overall as usize); {
let (info_tx, info_rx) = mpsc::channel(2); let (this, worker) = Self::with_settings(bot, settings);
let worker = worker(settings, rx, info_rx);
let this = Self {
bot,
queue: tx,
info_tx,
};
tokio::spawn(worker); tokio::spawn(worker);
this this
@ -241,6 +232,8 @@ impl Default for Limits {
pub struct Settings { pub struct Settings {
pub limits: Limits, pub limits: Limits,
pub on_queue_full: BoxedFnMut<usize, BoxedFuture>, pub on_queue_full: BoxedFnMut<usize, BoxedFuture>,
pub retry: bool,
pub check_slow_mode: bool,
} }
impl Settings { impl Settings {
@ -257,6 +250,16 @@ impl Settings {
self.on_queue_full = Box::new(move |pending| Box::pin(val(pending))); self.on_queue_full = Box::new(move |pending| Box::pin(val(pending)));
self self
} }
pub fn no_retry(mut self) -> Self {
self.retry = false;
self
}
pub fn check_slow_mode(mut self) -> Self {
self.check_slow_mode = true;
self
}
} }
impl Default for Settings { impl Default for Settings {
@ -267,6 +270,8 @@ impl Default for Settings {
log::warn!("Throttle queue is full ({} pending requests)", pending); log::warn!("Throttle queue is full ({} pending requests)", pending);
Box::pin(ready(())) Box::pin(ready(()))
}), }),
retry: true,
check_slow_mode: false,
} }
} }
} }
@ -306,6 +311,13 @@ struct RequestsSentToChats {
per_sec: HashMap<ChatIdHash, RequestsSent>, per_sec: HashMap<ChatIdHash, RequestsSent>,
} }
struct FreezeUntil {
until: Instant,
after: Duration,
chat: ChatIdHash,
retry: Option<RequestLock>,
}
// Throttling is quite complicated. This comment describes the algorithm of the // Throttling is quite complicated. This comment describes the algorithm of the
// current implementation. // current implementation.
// //
@ -349,14 +361,20 @@ struct RequestsSentToChats {
// the request that it can be now executed, increase counts, add record to the // the request that it can be now executed, increase counts, add record to the
// history. // history.
async fn worker( async fn worker<B>(
Settings { Settings {
mut limits, mut limits,
mut on_queue_full, mut on_queue_full,
retry,
check_slow_mode,
}: Settings, }: Settings,
mut rx: mpsc::Receiver<(ChatIdHash, RequestLock)>, mut rx: mpsc::Receiver<(ChatIdHash, RequestLock)>,
mut info_rx: mpsc::Receiver<InfoMessage>, mut info_rx: mpsc::Receiver<InfoMessage>,
) { bot: B,
) where
B: Requester,
B::Err: AsResponseParameters,
{
// FIXME(waffle): Make an research about data structures for this queue. // FIXME(waffle): Make an research about data structures for this queue.
// Currently this is O(n) removing (n = number of elements // Currently this is O(n) removing (n = number of elements
// stayed), amortized O(1) push (vec+vecrem). // stayed), amortized O(1) push (vec+vecrem).
@ -366,12 +384,17 @@ async fn worker(
let mut history: VecDeque<(ChatIdHash, Instant)> = VecDeque::new(); let mut history: VecDeque<(ChatIdHash, Instant)> = VecDeque::new();
let mut requests_sent = RequestsSentToChats::default(); let mut requests_sent = RequestsSentToChats::default();
let mut slow_mode: Option<HashMap<ChatIdHash, (Duration, Instant)>> =
check_slow_mode.then(HashMap::new);
let mut rx_is_closed = false; let mut rx_is_closed = false;
let mut last_queue_full = Instant::now() let mut last_queue_full = Instant::now()
.checked_sub(QUEUE_FULL_DELAY) .checked_sub(QUEUE_FULL_DELAY)
.unwrap_or_else(Instant::now); .unwrap_or_else(Instant::now);
let (freeze_tx, mut freeze_rx) = mpsc::channel::<FreezeUntil>(1);
while !rx_is_closed || !queue.is_empty() { while !rx_is_closed || !queue.is_empty() {
// FIXME(waffle): // FIXME(waffle):
// 1. If the `queue` is empty, `read_from_rx` call down below will 'block' // 1. If the `queue` is empty, `read_from_rx` call down below will 'block'
@ -383,7 +406,32 @@ async fn worker(
// *blocked in asynchronous way // *blocked in asynchronous way
answer_info(&mut info_rx, &mut limits); answer_info(&mut info_rx, &mut limits);
read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await; freeze(
&mut freeze_rx,
&freeze_tx,
slow_mode.as_mut(),
&mut queue,
&bot,
None,
)
.await;
loop {
tokio::select! {
() = read_from_rx(&mut rx, &mut queue, &mut rx_is_closed) => break,
freeze_until = freeze_rx.recv() => {
freeze(
&mut freeze_rx,
&freeze_tx,
slow_mode.as_mut(),
&mut queue,
&bot,
freeze_until
)
.await;
},
}
}
//debug_assert_eq!(queue.capacity(), limits.messages_per_sec_overall as usize); //debug_assert_eq!(queue.capacity(), limits.messages_per_sec_overall as usize);
if queue.len() == queue.capacity() && last_queue_full.elapsed() > QUEUE_FULL_DELAY { if queue.len() == queue.capacity() && last_queue_full.elapsed() > QUEUE_FULL_DELAY {
@ -470,6 +518,15 @@ async fn worker(
while let Some(entry) = queue_removing.next() { while let Some(entry) = queue_removing.next() {
let chat = &entry.value().0; let chat = &entry.value().0;
let slow_mode = slow_mode.as_mut().and_then(|sm| sm.get_mut(chat));
if let Some(&mut (delay, last)) = slow_mode {
if last + delay > Instant::now() {
continue;
}
}
let requests_sent_per_sec_count = requests_sent.per_sec.get(chat).copied().unwrap_or(0); let requests_sent_per_sec_count = requests_sent.per_sec.get(chat).copied().unwrap_or(0);
let requests_sent_per_min_count = requests_sent.per_min.get(chat).copied().unwrap_or(0); let requests_sent_per_min_count = requests_sent.per_min.get(chat).copied().unwrap_or(0);
@ -483,18 +540,26 @@ async fn worker(
&& requests_sent_per_min_count < messages_per_min_limit; && requests_sent_per_min_count < messages_per_min_limit;
if limits_not_exceeded { if limits_not_exceeded {
*requests_sent.per_sec.entry(*chat).or_insert(0) += 1; // Unlock the associated request.
*requests_sent.per_min.entry(*chat).or_insert(0) += 1;
history.push_back((*chat, Instant::now()));
// Close the channel and unlock the associated request. let chat = *chat;
let (_, lock) = entry.remove(); let (_, lock) = entry.remove();
lock.unlock();
// We have "sent" one request, so now we can send one less. // Only count request as sent if the request wasn't dropped before unlocked
allowed -= 1; if lock.unlock(retry, freeze_tx.clone()).is_ok() {
if allowed == 0 { *requests_sent.per_sec.entry(chat).or_insert(0) += 1;
break; *requests_sent.per_min.entry(chat).or_insert(0) += 1;
history.push_back((chat, Instant::now()));
if let Some((_, last)) = slow_mode {
*last = Instant::now();
}
// We have "sent" one request, so now we can send one less.
allowed -= 1;
if allowed == 0 {
break;
}
} }
} }
} }
@ -521,8 +586,78 @@ fn answer_info(rx: &mut mpsc::Receiver<InfoMessage>, limits: &mut Limits) {
} }
} }
async fn freeze(
rx: &mut mpsc::Receiver<FreezeUntil>,
tx: &mpsc::Sender<FreezeUntil>,
mut slow_mode: Option<&mut HashMap<ChatIdHash, (Duration, Instant)>>,
queue: &mut Vec<(ChatIdHash, RequestLock)>,
bot: &impl Requester,
mut imm: Option<FreezeUntil>,
) {
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
while let Some(FreezeUntil {
until,
after,
chat,
mut retry,
}) = imm.take().or_else(|| {
tokio::task::unconstrained(rx.recv())
.now_or_never()
.flatten()
}) {
if let Some(slow_mode) = slow_mode.as_deref_mut() {
// TODO: do something with channels?...
if let hash @ ChatIdHash::Id(id) = chat {
// TODO: maybe not call `get_chat` every time?
// At this point there isn't much we can do with the error besides ignoring
if let Ok(chat) = bot.get_chat(id).send().await {
match chat.slow_mode_delay() {
Some(delay) => {
let now = Instant::now();
let new_delay = Duration::from_secs(delay.into());
slow_mode.insert(hash, (new_delay, now));
}
None => {
slow_mode.remove(&hash);
}
};
}
}
}
// slow mode is enabled and it is <= to the delay asked by telegram
let slow_mode_enabled_and_likely_the_cause = slow_mode
.as_ref()
.and_then(|m| m.get(&chat).map(|(delay, _)| delay <= &after))
.unwrap_or(false);
// Do not sleep if slow mode is enabled since the freeze is most likely caused
// by the said slow mode and not by the global limits.
if slow_mode_enabled_and_likely_the_cause {
queue.extend(Some(chat).zip(retry.take()));
} else {
log::warn!(
"freezing the bot for approximately {:?} due to `RetryAfter` error from telegram",
after
);
tokio::time::sleep_until(until.into()).await;
log::warn!("unfreezing the bot");
if let Some(lock) = retry {
// Since we are already retrying the request, retries are obviously turned on.
let retry = true;
let _ = lock.unlock(retry, tx.clone());
}
}
}
}
async fn read_from_rx<T>(rx: &mut mpsc::Receiver<T>, queue: &mut Vec<T>, rx_is_closed: &mut bool) { async fn read_from_rx<T>(rx: &mut mpsc::Receiver<T>, queue: &mut Vec<T>, rx_is_closed: &mut bool) {
if queue.is_empty() { if queue.is_empty() {
log::warn!("A-blocking on queue");
match rx.recv().await { match rx.recv().await {
Some(req) => queue.push(req), Some(req) => queue.push(req),
None => *rx_is_closed = true, None => *rx_is_closed = true,
@ -544,7 +679,7 @@ async fn read_from_rx<T>(rx: &mut mpsc::Receiver<T>, queue: &mut Vec<T>, rx_is_c
macro_rules! f { macro_rules! f {
($m:ident $this:ident ($($arg:ident : $T:ty),*)) => { ($m:ident $this:ident ($($arg:ident : $T:ty),*)) => {
ThrottlingRequest { ThrottlingRequest {
request: $this.inner().$m($($arg),*), request: Arc::new($this.inner().$m($($arg),*)),
chat_id: |p| (&p.payload_ref().chat_id).into(), chat_id: |p| (&p.payload_ref().chat_id).into(),
worker: $this.queue.clone(), worker: $this.queue.clone(),
} }
@ -571,24 +706,26 @@ macro_rules! ftyid {
impl<B: Requester> Requester for Throttle<B> impl<B: Requester> Requester for Throttle<B>
where where
B::SendMessage: Send, B::Err: AsResponseParameters,
B::ForwardMessage: Send,
B::CopyMessage: Send, B::SendMessage: Clone + Send + Sync,
B::SendPhoto: Send, B::ForwardMessage: Clone + Send + Sync,
B::SendAudio: Send, B::CopyMessage: Clone + Send + Sync,
B::SendDocument: Send, B::SendPhoto: Clone + Send + Sync,
B::SendVideo: Send, B::SendAudio: Clone + Send + Sync,
B::SendAnimation: Send, B::SendDocument: Clone + Send + Sync,
B::SendVoice: Send, B::SendVideo: Clone + Send + Sync,
B::SendVideoNote: Send, B::SendAnimation: Clone + Send + Sync,
B::SendMediaGroup: Send, B::SendVoice: Clone + Send + Sync,
B::SendLocation: Send, B::SendVideoNote: Clone + Send + Sync,
B::SendVenue: Send, B::SendMediaGroup: Clone + Send + Sync,
B::SendContact: Send, B::SendLocation: Clone + Send + Sync,
B::SendPoll: Send, B::SendVenue: Clone + Send + Sync,
B::SendDice: Send, B::SendContact: Clone + Send + Sync,
B::SendSticker: Send, B::SendPoll: Clone + Send + Sync,
B::SendInvoice: Send, B::SendDice: Clone + Send + Sync,
B::SendSticker: Clone + Send + Sync,
B::SendInvoice: Clone + Send + Sync,
{ {
type Err = B::Err; type Err = B::Err;
@ -667,16 +804,16 @@ impl From<&Recipient> for ChatIdHash {
#[must_use = "Requests are lazy and do nothing unless sent"] #[must_use = "Requests are lazy and do nothing unless sent"]
pub struct ThrottlingRequest<R: HasPayload> { pub struct ThrottlingRequest<R: HasPayload> {
request: R, request: Arc<R>,
chat_id: fn(&R::Payload) -> ChatIdHash, chat_id: fn(&R::Payload) -> ChatIdHash,
worker: mpsc::Sender<(ChatIdHash, RequestLock)>, worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
} }
impl<R: HasPayload> HasPayload for ThrottlingRequest<R> { impl<R: HasPayload + Clone> HasPayload for ThrottlingRequest<R> {
type Payload = R::Payload; type Payload = R::Payload;
fn payload_mut(&mut self) -> &mut Self::Payload { fn payload_mut(&mut self) -> &mut Self::Payload {
self.request.payload_mut() Arc::make_mut(&mut self.request).payload_mut()
} }
fn payload_ref(&self) -> &Self::Payload { fn payload_ref(&self) -> &Self::Payload {
@ -686,11 +823,13 @@ impl<R: HasPayload> HasPayload for ThrottlingRequest<R> {
impl<R> Request for ThrottlingRequest<R> impl<R> Request for ThrottlingRequest<R>
where where
R: Request + Send, R: Request + Clone + Send + Sync,
R::Err: AsResponseParameters,
Output<R>: Send,
{ {
type Err = R::Err; type Err = R::Err;
type Send = ThrottlingSend<R>; type Send = ThrottlingSend<R>;
type SendRef = ThrottlingSendRef<R>; type SendRef = ThrottlingSend<R>;
fn send(self) -> Self::Send { fn send(self) -> Self::Send {
let (tx, rx) = channel(); let (tx, rx) = channel();
@ -699,7 +838,8 @@ where
let send = self.worker.send1((chat_id, tx)); let send = self.worker.send1((chat_id, tx));
let inner = ThrottlingSendInner::Registering { let inner = ThrottlingSendInner::Registering {
request: self.request, request: Arc::try_unwrap(self.request).into(),
chat: chat_id,
send, send,
wait: rx, wait: rx,
}; };
@ -718,69 +858,114 @@ where
// //
// However `Request` documentation explicitly notes that `send{,_ref}` // However `Request` documentation explicitly notes that `send{,_ref}`
// should **not** do any kind of work, so it's ok. // should **not** do any kind of work, so it's ok.
let request = self.request.send_ref(); let request = Either::Left(Arc::clone(&self.request));
let inner = ThrottlingSendRefInner::Registering { let inner = ThrottlingSendInner::Registering {
chat: chat_id,
request, request,
send, send,
wait: rx, wait: rx,
}; };
ThrottlingSendRef(inner) ThrottlingSend(inner)
} }
} }
use either::Either;
use std::sync::Arc;
#[pin_project::pin_project] #[pin_project::pin_project]
pub struct ThrottlingSend<R: Request>(#[pin] ThrottlingSendInner<R>); pub struct ThrottlingSend<R: Request>(#[pin] ThrottlingSendInner<R>);
#[pin_project::pin_project(project = SendProj, project_replace = SendRepl)] #[pin_project::pin_project(project = SendProj, project_replace = SendRepl)]
enum ThrottlingSendInner<R: Request> { enum ThrottlingSendInner<R: Request> {
Registering { Registering {
request: R, request: Either<Arc<R>, R>,
chat: ChatIdHash,
#[pin] #[pin]
send: ChanSend<(ChatIdHash, RequestLock)>, send: ChanSend<(ChatIdHash, RequestLock)>,
wait: RequestWaiter, wait: RequestWaiter,
}, },
Freezing {
#[pin]
freeze_fut: ChanSend<FreezeUntil>,
res: Result<Output<R>, R::Err>,
},
FreezingRetry {
#[pin]
freeze_fut: ChanSend<FreezeUntil>,
request: Either<Arc<R>, R>,
chat: ChatIdHash,
wait: RequestWaiter,
},
Pending { Pending {
request: R, request: Either<Arc<R>, R>,
chat: ChatIdHash,
#[pin] #[pin]
wait: RequestWaiter, wait: RequestWaiter,
}, },
Sent { Sent {
freeze: mpsc::Sender<FreezeUntil>,
chat: ChatIdHash,
#[pin] #[pin]
fut: R::Send, fut: R::Send,
}, },
SentRef {
freeze: mpsc::Sender<FreezeUntil>,
chat: ChatIdHash,
#[pin]
fut: R::SendRef,
},
SentRetryable {
request: Either<Arc<R>, R>,
chat: ChatIdHash,
freeze: mpsc::Sender<FreezeUntil>,
#[pin]
fut: R::SendRef,
},
Done, Done,
} }
impl<R: Request> Future for ThrottlingSend<R> { impl<R: Request> Future for ThrottlingSend<R>
where
R::Err: AsResponseParameters,
{
type Output = Result<Output<R>, R::Err>; type Output = Result<Output<R>, R::Err>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project().0; let mut this = self.as_mut().project().0;
match this.as_mut().project() { match this.as_mut().project() {
SendProj::Registering { SendProj::Registering { send, .. } => match send.poll(cx) {
request: _,
send,
wait: _,
} => match send.poll(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(res) => { Poll::Ready(res) => {
if let SendRepl::Registering { if let SendRepl::Registering {
request, request,
send: _, send: _,
wait, wait,
chat,
} = this.as_mut().project_replace(ThrottlingSendInner::Done) } = this.as_mut().project_replace(ThrottlingSendInner::Done)
{ {
match res { match res {
Ok(()) => this Ok(()) => this.as_mut().set(ThrottlingSendInner::Pending {
.as_mut() request,
.project_replace(ThrottlingSendInner::Pending { request, wait }), wait,
chat,
}),
// The worker is unlikely to drop queue before sending all requests, // The worker is unlikely to drop queue before sending all requests,
// but just in case it has dropped the queue, we want to just send the // but just in case it has dropped the queue, we want to just send the
// request. // request.
Err(_) => this.as_mut().project_replace(ThrottlingSendInner::Sent { Err(_) => this.as_mut().set(match request {
fut: request.send(), Either::Left(shared) => ThrottlingSendInner::SentRef {
freeze: mpsc::channel(1).0,
fut: shared.send_ref(),
chat,
},
Either::Right(owned) => ThrottlingSendInner::Sent {
freeze: mpsc::channel(1).0,
fut: owned.send(),
chat,
},
}), }),
}; };
} }
@ -788,116 +973,154 @@ impl<R: Request> Future for ThrottlingSend<R> {
self.poll(cx) self.poll(cx)
} }
}, },
SendProj::Pending { request: _, wait } => match wait.poll(cx) { SendProj::Pending { wait, .. } => match wait.poll(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
// Worker pass "message" to unlock us by closing the channel, // Worker pass "message" to unlock us by closing the channel,
// and thus we can safely ignore this result as we know it will // and thus we can safely ignore this result as we know it will
// always be `Err(_)` (because `Ok(Never)` is uninhibited) // always be `Err(_)` (because `Ok(Never)` is uninhibited)
// and that's what we want. // and that's what we want.
Poll::Ready(_) => { Poll::Ready((retry, freeze)) => {
if let SendRepl::Pending { request, wait: _ } = if let SendRepl::Pending { request, chat, .. } =
this.as_mut().project_replace(ThrottlingSendInner::Done) this.as_mut().project_replace(ThrottlingSendInner::Done)
{ {
this.as_mut().project_replace(ThrottlingSendInner::Sent { let repl = match (retry, request) {
fut: request.send(), (true, request) => ThrottlingSendInner::SentRetryable {
}); fut: request.as_ref().either(|r| &**r, |r| r).send_ref(),
} chat,
request,
self.poll(cx) freeze,
} },
}, (false, Either::Left(shared)) => ThrottlingSendInner::SentRef {
SendProj::Sent { fut } => { fut: shared.send_ref(),
let res = futures::ready!(fut.poll(cx)); chat,
this.set(ThrottlingSendInner::Done); freeze,
Poll::Ready(res) },
} (false, Either::Right(owned)) => ThrottlingSendInner::Sent {
SendProj::Done => Poll::Pending, fut: owned.send(),
} chat,
} freeze,
} },
#[pin_project::pin_project]
pub struct ThrottlingSendRef<R: Request>(#[pin] ThrottlingSendRefInner<R>);
#[pin_project::pin_project(project = SendRefProj, project_replace = SendRefRepl)]
enum ThrottlingSendRefInner<R: Request> {
Registering {
request: R::SendRef,
#[pin]
send: ChanSend<(ChatIdHash, RequestLock)>,
wait: RequestWaiter,
},
Pending {
request: R::SendRef,
#[pin]
wait: RequestWaiter,
},
Sent {
#[pin]
fut: R::SendRef,
},
Done,
}
impl<R: Request> Future for ThrottlingSendRef<R> {
type Output = Result<Output<R>, R::Err>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project().0;
match this.as_mut().project() {
SendRefProj::Registering {
request: _,
send,
wait: _,
} => match send.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(res) => {
if let SendRefRepl::Registering {
request,
send: _,
wait,
} = this.as_mut().project_replace(ThrottlingSendRefInner::Done)
{
match res {
Ok(()) => this
.as_mut()
.project_replace(ThrottlingSendRefInner::Pending { request, wait }),
// The worker is unlikely to drop queue before sending all requests,
// but just in case it has dropped the queue, we want to just send the
// request.
Err(_) => this
.as_mut()
.project_replace(ThrottlingSendRefInner::Sent { fut: request }),
}; };
this.as_mut().project_replace(repl);
} }
self.poll(cx) self.poll(cx)
} }
}, },
SendRefProj::Pending { request: _, wait } => match wait.poll(cx) { SendProj::Freezing { freeze_fut, .. } => {
Poll::Pending => Poll::Pending, // Error here means that the worker died, so we can't really do anything about
// Worker pass "message" to unlock us by closing the channel, // it
// and thus we can safely ignore this result as we know it will let _ = futures::ready!(freeze_fut.poll(cx));
// always be `Err(_)` (because `Ok(Never)` is uninhibited) if let SendRepl::Freezing { res, .. } =
// and that's what we want. this.as_mut().project_replace(ThrottlingSendInner::Done)
Poll::Ready(_) => { {
if let SendRefRepl::Pending { request, wait: _ } = Poll::Ready(res)
this.as_mut().project_replace(ThrottlingSendRefInner::Done) } else {
{ // The match above guarantees that this is unreachable
this.as_mut() unreachable!()
.project_replace(ThrottlingSendRefInner::Sent { fut: request });
}
self.poll(cx)
} }
},
SendRefProj::Sent { fut } => {
let res = futures::ready!(fut.poll(cx));
this.set(ThrottlingSendRefInner::Done);
Poll::Ready(res)
} }
SendRefProj::Done => Poll::Pending, SendProj::FreezingRetry { freeze_fut, .. } => {
// Error here means that the worker died, so we can't really do anything about
// it
let _ = futures::ready!(freeze_fut.poll(cx));
if let SendRepl::FreezingRetry {
request,
chat,
wait,
..
} = this.as_mut().project_replace(ThrottlingSendInner::Done)
{
this.as_mut().set(ThrottlingSendInner::Pending {
request,
chat,
wait,
});
self.poll(cx)
} else {
unreachable!()
}
}
SendProj::Sent { fut, freeze, chat } => {
let res = futures::ready!(fut.poll(cx));
if let Err(Some(retry_after)) = res.as_ref().map_err(<_>::retry_after) {
let after = Duration::from_secs(retry_after.into());
let freeze_fut = freeze.clone().send1(FreezeUntil {
until: Instant::now() + after,
after,
chat: *chat,
retry: None,
});
this.set(ThrottlingSendInner::Freezing { freeze_fut, res });
self.poll(cx)
} else {
this.set(ThrottlingSendInner::Done);
Poll::Ready(res)
}
}
SendProj::SentRef { freeze, fut, chat } => {
let res = futures::ready!(fut.poll(cx));
if let Err(Some(retry_after)) = res.as_ref().map_err(<_>::retry_after) {
let after = Duration::from_secs(retry_after.into());
let freeze_fut = freeze.clone().send1(FreezeUntil {
until: Instant::now() + after,
after,
chat: *chat,
retry: None,
});
this.set(ThrottlingSendInner::Freezing { freeze_fut, res });
self.poll(cx)
} else {
this.set(ThrottlingSendInner::Done);
Poll::Ready(res)
}
}
SendProj::SentRetryable { fut, .. } => {
let res = futures::ready!(fut.poll(cx));
let (lock, wait) = channel();
if let Err(Some(retry_after)) = res.as_ref().map_err(<_>::retry_after) {
let after = Duration::from_secs(retry_after.into());
if let SendRepl::SentRetryable {
request,
freeze,
chat,
..
} = this.as_mut().project_replace(ThrottlingSendInner::Done)
{
log::warn!("Freezing, before retrying: {}", retry_after);
let freeze_fut = freeze.send1(FreezeUntil {
until: Instant::now(),
after,
chat,
retry: Some(lock),
});
this.as_mut().set(ThrottlingSendInner::FreezingRetry {
freeze_fut,
request,
chat,
wait,
})
}
self.poll(cx)
} else {
this.set(ThrottlingSendInner::Done);
Poll::Ready(res)
}
}
SendProj::Done => {
log::error!("Polling done");
Poll::Pending
}
} }
} }
} }
@ -910,25 +1133,26 @@ fn channel() -> (RequestLock, RequestWaiter) {
} }
#[must_use] #[must_use]
struct RequestLock(Sender<Never>); struct RequestLock(Sender<(bool, mpsc::Sender<FreezeUntil>)>);
impl RequestLock { impl RequestLock {
fn unlock(self) { fn unlock(self, retry: bool, freeze: mpsc::Sender<FreezeUntil>) -> Result<(), ()> {
// Unlock request by closing oneshot channel self.0.send((retry, freeze)).map_err(drop)
} }
} }
#[must_use] #[must_use]
#[pin_project::pin_project] #[pin_project::pin_project]
struct RequestWaiter(#[pin] Receiver<Never>); struct RequestWaiter(#[pin] Receiver<(bool, mpsc::Sender<FreezeUntil>)>);
impl Future for RequestWaiter { impl Future for RequestWaiter {
type Output = (); type Output = (bool, mpsc::Sender<FreezeUntil>);
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project(); let this = self.project();
match this.0.poll(cx) { match this.0.poll(cx) {
Poll::Ready(_) => Poll::Ready(()), Poll::Ready(Ok(ret)) => Poll::Ready(ret),
Poll::Ready(Err(_)) => panic!("`RequestLock` is dropped by the throttle worker"),
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
} }
} }

View file

@ -1,4 +1,6 @@
use crate::{adaptors::DefaultParseMode, requests::Requester, types::ParseMode}; use crate::{
adaptors::DefaultParseMode, errors::AsResponseParameters, requests::Requester, types::ParseMode,
};
#[cfg(feature = "cache_me")] #[cfg(feature = "cache_me")]
use crate::adaptors::CacheMe; use crate::adaptors::CacheMe;
@ -60,7 +62,9 @@ pub trait RequesterExt: Requester {
#[cfg(feature = "throttle")] #[cfg(feature = "throttle")]
fn throttle(self, limits: Limits) -> Throttle<Self> fn throttle(self, limits: Limits) -> Throttle<Self>
where where
Self: Sized, Self: Sized + Clone + Send + Sync + 'static,
Self::Err: AsResponseParameters,
Self::GetChat: Send,
{ {
Throttle::new_spawn(self, limits) Throttle::new_spawn(self, limits)
} }