mirror of
https://github.com/teloxide/teloxide.git
synced 2025-03-13 11:18:17 +01:00
Merge pull request #50 from teloxide/clean_throttle
Make some cleanup of throttle
This commit is contained in:
commit
1b51f2616c
2 changed files with 95 additions and 65 deletions
|
@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- `RequesterExt` trait which is implemented for all `Requester`s and allows easily wrapping them in adaptors
|
||||
- `adaptors` module ([#14][pr14])
|
||||
- `throttle`, `cache_me`, `auto_send` and `full` crate features
|
||||
- Request throttling - opt-in feature represented by `Throttle` bot adapter which allows automatically checking telegram limits ([#10][pr10], [#46][pr46])
|
||||
- Request throttling - opt-in feature represented by `Throttle` bot adapter which allows automatically checking telegram limits ([#10][pr10], [#46][pr46], [#50][pr50])
|
||||
- Request auto sending - ability to `.await` requests without need to call `.send()` (opt-in feature represented by `AutoSend` bot adapter, [#8][pr8])
|
||||
- `get_me` caching (opt-in feature represented by `CacheMe` bot adapter)
|
||||
- `Requester` trait which represents bot-clients ([#7][pr7], [#12][pr12], [#27][pr27])
|
||||
|
@ -55,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
[pr39]: https://github.com/teloxide/teloxide-core/pull/39
|
||||
[pr46]: https://github.com/teloxide/teloxide-core/pull/46
|
||||
[pr49]: https://github.com/teloxide/teloxide-core/pull/49
|
||||
[pr50]: https://github.com/teloxide/teloxide-core/pull/50
|
||||
|
||||
### Changed
|
||||
|
||||
|
|
|
@ -13,12 +13,12 @@ use futures::{
|
|||
use never::Never;
|
||||
use tokio::sync::{
|
||||
mpsc,
|
||||
oneshot::{channel, Receiver, Sender},
|
||||
oneshot::{self, Receiver, Sender},
|
||||
};
|
||||
use vecrem::VecExt;
|
||||
|
||||
use crate::{
|
||||
adaptors::throttle::chan_send::{ChanSend, SendTy},
|
||||
adaptors::throttle::chan_send::{ChanSend, MpscSend},
|
||||
requests::{HasPayload, Output, Request, Requester},
|
||||
types::*,
|
||||
};
|
||||
|
@ -129,7 +129,6 @@ impl Default for Limits {
|
|||
/// ```no_run (throttle fails to spawn task without tokio runtime)
|
||||
/// use teloxide_core::{adaptors::throttle::Limits, requests::RequesterExt, Bot};
|
||||
///
|
||||
/// # #[allow(deprecated)]
|
||||
/// let bot = Bot::new("TOKEN").throttle(Limits::default());
|
||||
///
|
||||
/// /* send many requests here */
|
||||
|
@ -149,8 +148,8 @@ impl Default for Limits {
|
|||
/// wrapper.
|
||||
pub struct Throttle<B> {
|
||||
bot: B,
|
||||
// Sender<Never> is used to pass the signal to unlock by closing the channel.
|
||||
queue: mpsc::Sender<(ChatIdHash, Sender<Never>)>,
|
||||
// `RequestLock` allows to unlock requests (allowing them to be sent).
|
||||
queue: mpsc::Sender<(ChatIdHash, RequestLock)>,
|
||||
}
|
||||
|
||||
type RequestsSent = u32;
|
||||
|
@ -164,11 +163,11 @@ struct RequestsSentToChats {
|
|||
per_sec: HashMap<ChatIdHash, RequestsSent>,
|
||||
}
|
||||
|
||||
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, Sender<Never>)>) {
|
||||
async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, RequestLock)>) {
|
||||
// FIXME(waffle): Make an research about data structures for this queue.
|
||||
// Currently this is O(n) removing (n = number of elements
|
||||
// stayed), amortized O(1) push (vec+vecrem).
|
||||
let mut queue: Vec<(ChatIdHash, Sender<Never>)> =
|
||||
let mut queue: Vec<(ChatIdHash, RequestLock)> =
|
||||
Vec::with_capacity(limits.messages_per_sec_overall as usize);
|
||||
|
||||
let mut history: VecDeque<(ChatIdHash, Instant)> = VecDeque::new();
|
||||
|
@ -178,6 +177,7 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, Sender<Never
|
|||
|
||||
while !rx_is_closed || !queue.is_empty() {
|
||||
read_from_rx(&mut rx, &mut queue, &mut rx_is_closed).await;
|
||||
debug_assert_eq!(queue.capacity(), limits.messages_per_sec_overall as usize);
|
||||
|
||||
// _Maybe_ we need to use `spawn_blocking` here, because there is
|
||||
// decent amount of blocking work. However _for now_ I've decided not
|
||||
|
@ -268,7 +268,8 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, Sender<Never
|
|||
history.push_back((*chat, Instant::now()));
|
||||
|
||||
// Close the channel and unlock the associated request.
|
||||
drop(entry.remove());
|
||||
let (_, lock) = entry.remove();
|
||||
lock.unlock();
|
||||
|
||||
// We have "sent" one request, so now we can send one less.
|
||||
allowed -= 1;
|
||||
|
@ -285,11 +286,7 @@ async fn worker(limits: Limits, mut rx: mpsc::Receiver<(ChatIdHash, Sender<Never
|
|||
}
|
||||
}
|
||||
|
||||
async fn read_from_rx(
|
||||
rx: &mut mpsc::Receiver<(ChatIdHash, Sender<Never>)>,
|
||||
queue: &mut Vec<(ChatIdHash, Sender<Never>)>,
|
||||
rx_is_closed: &mut bool,
|
||||
) {
|
||||
async fn read_from_rx<T>(rx: &mut mpsc::Receiver<T>, queue: &mut Vec<T>, rx_is_closed: &mut bool) {
|
||||
if queue.is_empty() {
|
||||
match rx.recv().await {
|
||||
Some(req) => queue.push(req),
|
||||
|
@ -297,7 +294,8 @@ async fn read_from_rx(
|
|||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
// Don't grow queue bigger than the capacity to limit DOS posibility
|
||||
while queue.len() < queue.capacity() {
|
||||
// FIXME(waffle): https://github.com/tokio-rs/tokio/issues/3350
|
||||
match rx.recv().now_or_never() {
|
||||
Some(Some(req)) => queue.push(req),
|
||||
|
@ -359,11 +357,11 @@ impl<B> Throttle<B> {
|
|||
|
||||
macro_rules! f {
|
||||
($m:ident $this:ident ($($arg:ident : $T:ty),*)) => {
|
||||
ThrottlingRequest(
|
||||
$this.inner().$m($($arg),*),
|
||||
$this.queue.clone(),
|
||||
|p| (&p.payload_ref().chat_id).into(),
|
||||
)
|
||||
ThrottlingRequest {
|
||||
request: $this.inner().$m($($arg),*),
|
||||
chat_id: |p| (&p.payload_ref().chat_id).into(),
|
||||
worker: $this.queue.clone(),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -436,8 +434,8 @@ where
|
|||
C: Into<String>,
|
||||
Pri: IntoIterator<Item = LabeledPrice>,
|
||||
{
|
||||
ThrottlingRequest(
|
||||
self.inner().send_invoice(
|
||||
ThrottlingRequest {
|
||||
request: self.inner().send_invoice(
|
||||
chat_id,
|
||||
title,
|
||||
description,
|
||||
|
@ -447,9 +445,9 @@ where
|
|||
currency,
|
||||
prices,
|
||||
),
|
||||
self.queue.clone(),
|
||||
|p| ChatIdHash::Id(p.payload_ref().chat_id as _),
|
||||
)
|
||||
chat_id: |p| ChatIdHash::Id(p.payload_ref().chat_id as _),
|
||||
worker: self.queue.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
requester_forward! {
|
||||
|
@ -506,21 +504,21 @@ impl From<&ChatId> for ChatIdHash {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ThrottlingRequest<R: HasPayload>(
|
||||
R,
|
||||
mpsc::Sender<(ChatIdHash, Sender<Never>)>,
|
||||
fn(&R::Payload) -> ChatIdHash,
|
||||
);
|
||||
pub struct ThrottlingRequest<R: HasPayload> {
|
||||
request: R,
|
||||
chat_id: fn(&R::Payload) -> ChatIdHash,
|
||||
worker: mpsc::Sender<(ChatIdHash, RequestLock)>,
|
||||
}
|
||||
|
||||
impl<R: HasPayload> HasPayload for ThrottlingRequest<R> {
|
||||
type Payload = R::Payload;
|
||||
|
||||
fn payload_mut(&mut self) -> &mut Self::Payload {
|
||||
self.0.payload_mut()
|
||||
self.request.payload_mut()
|
||||
}
|
||||
|
||||
fn payload_ref(&self) -> &Self::Payload {
|
||||
self.0.payload_ref()
|
||||
self.request.payload_ref()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -534,18 +532,23 @@ where
|
|||
|
||||
fn send(self) -> Self::Send {
|
||||
let (tx, rx) = channel();
|
||||
let id = self.2(self.payload_ref());
|
||||
let send = self.1.send_t((id, tx));
|
||||
ThrottlingSend(ThrottlingSendInner::Registering {
|
||||
request: self.0,
|
||||
|
||||
let chat_id = (self.chat_id)(self.payload_ref());
|
||||
let send = self.worker.send1((chat_id, tx));
|
||||
|
||||
let inner = ThrottlingSendInner::Registering {
|
||||
request: self.request,
|
||||
send,
|
||||
wait: rx,
|
||||
})
|
||||
};
|
||||
ThrottlingSend(inner)
|
||||
}
|
||||
|
||||
fn send_ref(&self) -> Self::SendRef {
|
||||
let (tx, rx) = channel();
|
||||
let send = self.1.clone().send_t((self.2(self.payload_ref()), tx));
|
||||
|
||||
let chat_id = (self.chat_id)(self.payload_ref());
|
||||
let send = self.worker.clone().send1((chat_id, tx));
|
||||
|
||||
// As we can't move self.0 (request) out, as we do in `send` we are
|
||||
// forced to call `send_ref()`. This may have overhead and/or lead to
|
||||
|
@ -553,13 +556,14 @@ where
|
|||
//
|
||||
// However `Request` documentation explicitly notes that `send{,_ref}`
|
||||
// should **not** do any kind of work, so it's ok.
|
||||
let request = self.0.send_ref();
|
||||
let request = self.request.send_ref();
|
||||
|
||||
ThrottlingSendRef(ThrottlingSendRefInner::Registering {
|
||||
let inner = ThrottlingSendRefInner::Registering {
|
||||
request,
|
||||
send,
|
||||
wait: rx,
|
||||
})
|
||||
};
|
||||
ThrottlingSendRef(inner)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -571,13 +575,13 @@ enum ThrottlingSendInner<R: Request> {
|
|||
Registering {
|
||||
request: R,
|
||||
#[pin]
|
||||
send: ChanSend,
|
||||
wait: Receiver<Never>,
|
||||
send: ChanSend<(ChatIdHash, RequestLock)>,
|
||||
wait: RequestWaiter,
|
||||
},
|
||||
Pending {
|
||||
request: R,
|
||||
#[pin]
|
||||
wait: Receiver<Never>,
|
||||
wait: RequestWaiter,
|
||||
},
|
||||
Sent {
|
||||
#[pin]
|
||||
|
@ -658,13 +662,13 @@ enum ThrottlingSendRefInner<R: Request> {
|
|||
Registering {
|
||||
request: R::SendRef,
|
||||
#[pin]
|
||||
send: ChanSend,
|
||||
wait: Receiver<Never>,
|
||||
send: ChanSend<(ChatIdHash, RequestLock)>,
|
||||
wait: RequestWaiter,
|
||||
},
|
||||
Pending {
|
||||
request: R::SendRef,
|
||||
#[pin]
|
||||
wait: Receiver<Never>,
|
||||
wait: RequestWaiter,
|
||||
},
|
||||
Sent {
|
||||
#[pin]
|
||||
|
@ -736,38 +740,63 @@ impl<R: Request> Future for ThrottlingSendRef<R> {
|
|||
}
|
||||
}
|
||||
|
||||
fn channel() -> (RequestLock, RequestWaiter) {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let tx = RequestLock(tx);
|
||||
let rx = RequestWaiter(rx);
|
||||
(tx, rx)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
struct RequestLock(Sender<Never>);
|
||||
|
||||
impl RequestLock {
|
||||
fn unlock(self) {
|
||||
// Unlock request by closing oneshot channel
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[pin_project::pin_project]
|
||||
struct RequestWaiter(#[pin] Receiver<Never>);
|
||||
|
||||
impl Future for RequestWaiter {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let this = self.project();
|
||||
match this.0.poll(cx) {
|
||||
Poll::Ready(_) => Poll::Ready(()),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod chan_send {
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use futures::task::{Context, Poll};
|
||||
use never::Never;
|
||||
use tokio::sync::{mpsc, mpsc::error::SendError, oneshot::Sender};
|
||||
use tokio::sync::{mpsc, mpsc::error::SendError};
|
||||
|
||||
use crate::adaptors::throttle::ChatIdHash;
|
||||
|
||||
pub(super) trait SendTy {
|
||||
fn send_t(self, val: (ChatIdHash, Sender<Never>)) -> ChanSend;
|
||||
pub(super) trait MpscSend<T> {
|
||||
fn send1(self, val: T) -> ChanSend<T>;
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub(super) struct ChanSend(#[pin] Inner);
|
||||
pub(super) struct ChanSend<T>(#[pin] Inner<T>);
|
||||
|
||||
#[cfg(not(feature = "nightly"))]
|
||||
type Inner =
|
||||
Pin<Box<dyn Future<Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>> + Send>>;
|
||||
type Inner<T> = Pin<Box<dyn Future<Output = Result<(), SendError<T>>> + Send>>;
|
||||
#[cfg(feature = "nightly")]
|
||||
type Inner = impl Future<Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>>;
|
||||
type Inner<T> = impl Future<Output = Result<(), SendError<T>>>;
|
||||
|
||||
impl SendTy for mpsc::Sender<(ChatIdHash, Sender<Never>)> {
|
||||
impl<T: Send + 'static> MpscSend<T> for mpsc::Sender<T> {
|
||||
// `return`s trick IDEA not to show errors
|
||||
#[allow(clippy::needless_return)]
|
||||
fn send_t(self, val: (ChatIdHash, Sender<Never>)) -> ChanSend {
|
||||
fn send1(self, val: T) -> ChanSend<T> {
|
||||
#[cfg(feature = "nightly")]
|
||||
{
|
||||
fn def(
|
||||
sender: mpsc::Sender<(ChatIdHash, Sender<Never>)>,
|
||||
val: (ChatIdHash, Sender<Never>),
|
||||
) -> Inner {
|
||||
fn def<T>(sender: mpsc::Sender<T>, val: T) -> Inner<T> {
|
||||
async move { sender.send(val).await }
|
||||
}
|
||||
return ChanSend(def(self, val));
|
||||
|
@ -780,8 +809,8 @@ mod chan_send {
|
|||
}
|
||||
}
|
||||
|
||||
impl Future for ChanSend {
|
||||
type Output = Result<(), SendError<(ChatIdHash, Sender<Never>)>>;
|
||||
impl<T> Future for ChanSend<T> {
|
||||
type Output = Result<(), SendError<T>>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project().0.poll(cx)
|
||||
|
|
Loading…
Add table
Reference in a new issue