mirror of
https://github.com/teloxide/teloxide.git
synced 2024-12-22 14:35:36 +01:00
Refactor UpdateListener
to allow properly setting allowed updates for webhooks
This commit is contained in:
parent
0d47b40137
commit
316ca97886
11 changed files with 469 additions and 437 deletions
|
@ -21,7 +21,7 @@ categories = ["web-programming", "api-bindings", "asynchronous"]
|
|||
default = ["native-tls", "ctrlc_handler", "teloxide-core/default", "auto-send"]
|
||||
|
||||
webhooks = ["rand"]
|
||||
webhooks-axum = ["webhooks", "axum", "tower", "tower-http"]
|
||||
webhooks-axum = ["webhooks", "axum", "tower", "tower-http", "hyper"]
|
||||
|
||||
# FIXME: rename `sqlite-storage` -> `sqlite-storage-nativetls`
|
||||
sqlite-storage = ["sqlx", "sqlx/runtime-tokio-native-tls", "native-tls"]
|
||||
|
@ -109,6 +109,7 @@ bincode = { version = "1.3", optional = true }
|
|||
axum = { version = "0.6.0", optional = true }
|
||||
tower = { version = "0.4.12", optional = true }
|
||||
tower-http = { version = "0.3.4", features = ["trace"], optional = true }
|
||||
hyper = { version = "0.14", optional = true, default-features = false }
|
||||
rand = { version = "0.8.5", optional = true }
|
||||
|
||||
|
||||
|
|
|
@ -41,9 +41,7 @@ async fn main() {
|
|||
let host = env::var("HOST").expect("HOST env variable is not set");
|
||||
let url = format!("https://{host}/webhook").parse().unwrap();
|
||||
|
||||
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url))
|
||||
.await
|
||||
.expect("Couldn't setup webhook");
|
||||
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url));
|
||||
|
||||
teloxide::repl_with_listener(
|
||||
bot,
|
||||
|
|
|
@ -12,9 +12,7 @@ async fn main() {
|
|||
|
||||
let addr = ([127, 0, 0, 1], 8443).into();
|
||||
let url = "Your HTTPS ngrok URL here. Get it by `ngrok http 8443`".parse().unwrap();
|
||||
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url))
|
||||
.await
|
||||
.expect("Couldn't setup webhook");
|
||||
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url));
|
||||
|
||||
teloxide::repl_with_listener(
|
||||
bot,
|
||||
|
|
|
@ -20,7 +20,8 @@ use tokio_stream::wrappers::ReceiverStream;
|
|||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt::Debug,
|
||||
error::Error,
|
||||
fmt::{self, Debug, Display},
|
||||
future::Future,
|
||||
hash::Hash,
|
||||
ops::{ControlFlow, Deref},
|
||||
|
@ -214,6 +215,15 @@ pub struct Dispatcher<R, Err, Key> {
|
|||
state: ShutdownToken,
|
||||
}
|
||||
|
||||
/// An error returned from [`Disatcher::try_dispatch_with_listener`].
|
||||
pub enum TryDispatchError<R: Requester, L: UpdateListener> {
|
||||
/// An error from calling `get_me` while creating dispatcher context.
|
||||
GetMe(R::Err),
|
||||
|
||||
/// An error during update listener setup.
|
||||
ListenerSetup(L::SetupErr),
|
||||
}
|
||||
|
||||
struct Worker {
|
||||
tx: tokio::sync::mpsc::Sender<Update>,
|
||||
handle: tokio::task::JoinHandle<()>,
|
||||
|
@ -300,8 +310,8 @@ where
|
|||
update_listener_error_handler: Arc<Eh>,
|
||||
) where
|
||||
UListener: UpdateListener + 'a,
|
||||
Eh: ErrorHandler<UListener::Err> + 'a,
|
||||
UListener::Err: Debug,
|
||||
Eh: ErrorHandler<UListener::StreamErr> + 'a,
|
||||
UListener::SetupErr: Debug,
|
||||
{
|
||||
self.try_dispatch_with_listener(update_listener, update_listener_error_handler)
|
||||
.await
|
||||
|
@ -319,14 +329,13 @@ where
|
|||
&'a mut self,
|
||||
mut update_listener: UListener,
|
||||
update_listener_error_handler: Arc<Eh>,
|
||||
) -> Result<(), R::Err>
|
||||
) -> Result<(), TryDispatchError<R, UListener>>
|
||||
where
|
||||
UListener: UpdateListener + 'a,
|
||||
Eh: ErrorHandler<UListener::Err> + 'a,
|
||||
UListener::Err: Debug,
|
||||
Eh: ErrorHandler<UListener::StreamErr> + 'a,
|
||||
{
|
||||
// FIXME: there should be a way to check if dependency is already inserted
|
||||
let me = self.bot.get_me().send().await?;
|
||||
let me = self.bot.get_me().send().await.map_err(TryDispatchError::GetMe)?;
|
||||
self.dependencies.insert(me);
|
||||
self.dependencies.insert(self.bot.clone());
|
||||
|
||||
|
@ -340,7 +349,7 @@ where
|
|||
self.state.start_dispatching();
|
||||
|
||||
{
|
||||
let stream = update_listener.as_stream();
|
||||
let stream = update_listener.listen().await.map_err(TryDispatchError::ListenerSetup)?;
|
||||
tokio::pin!(stream);
|
||||
|
||||
loop {
|
||||
|
@ -521,6 +530,51 @@ impl<R, Err, Key> Dispatcher<R, Err, Key> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Requester, L: UpdateListener> Debug for TryDispatchError<R, L>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Debug,
|
||||
L: UpdateListener,
|
||||
L::SetupErr: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::GetMe(e) => f.debug_tuple("GetMe").field(&e).finish(),
|
||||
Self::ListenerSetup(e) => f.debug_tuple("ListenerSetup").field(&e).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, L> fmt::Display for TryDispatchError<R, L>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Display,
|
||||
L: UpdateListener,
|
||||
L::SetupErr: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::GetMe(e) => write!(f, "Error while setting up update listener: {e}"),
|
||||
Self::ListenerSetup(e) => write!(f, "Error while setting up update listener: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, L> Error for TryDispatchError<R, L>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Error + Debug + Display + 'static,
|
||||
L: UpdateListener,
|
||||
L::SetupErr: Error + Debug + Display + 'static,
|
||||
{
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::GetMe(e) => Some(e),
|
||||
Self::ListenerSetup(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_worker<Err>(
|
||||
deps: DependencyMap,
|
||||
handler: Arc<UpdateHandler<Err>>,
|
||||
|
|
|
@ -85,8 +85,11 @@ pub trait CommandReplExt {
|
|||
fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()>
|
||||
where
|
||||
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
|
||||
L: UpdateListener + Send + 'a,
|
||||
L::Err: Debug + Send + 'a,
|
||||
// FIXME: why do we need + 'static here??? (and in similar cases) (compiler error does not
|
||||
// provide enough info...)
|
||||
L: UpdateListener + Send + 'static,
|
||||
L::SetupErr: Debug,
|
||||
L::StreamErr: Debug + Send + 'a,
|
||||
R: Requester + Clone + Send + Sync + 'static,
|
||||
<R as Requester>::GetMe: Send;
|
||||
}
|
||||
|
@ -120,8 +123,9 @@ where
|
|||
fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()>
|
||||
where
|
||||
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
|
||||
L: UpdateListener + Send + 'a,
|
||||
L::Err: Debug + Send + 'a,
|
||||
L: UpdateListener + Send + 'static,
|
||||
L::SetupErr: Debug,
|
||||
L::StreamErr: Debug + Send + 'a,
|
||||
R: Requester + Clone + Send + Sync + 'static,
|
||||
<R as Requester>::GetMe: Send,
|
||||
{
|
||||
|
@ -277,7 +281,8 @@ pub async fn commands_repl_with_listener<'a, R, Cmd, H, L, Args>(
|
|||
Cmd: BotCommands + Send + Sync + 'static,
|
||||
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
|
||||
L: UpdateListener + Send + 'a,
|
||||
L::Err: Debug + Send + 'a,
|
||||
L::SetupErr: Debug,
|
||||
L::StreamErr: Debug,
|
||||
R: Requester + Clone + Send + Sync + 'static,
|
||||
{
|
||||
use crate::dispatching::Dispatcher;
|
||||
|
|
|
@ -108,7 +108,8 @@ where
|
|||
R: Requester + Clone + Send + Sync + 'static,
|
||||
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
|
||||
L: UpdateListener + Send,
|
||||
L::Err: Debug,
|
||||
L::SetupErr: Debug,
|
||||
L::StreamErr: Debug,
|
||||
{
|
||||
use crate::dispatching::Dispatcher;
|
||||
|
||||
|
|
|
@ -30,7 +30,9 @@
|
|||
#[cfg(feature = "webhooks")]
|
||||
pub mod webhooks;
|
||||
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
use futures::{Future, Stream};
|
||||
|
||||
use crate::{
|
||||
stop::StopToken,
|
||||
|
@ -38,30 +40,54 @@ use crate::{
|
|||
};
|
||||
|
||||
mod polling;
|
||||
mod stateful_listener;
|
||||
|
||||
#[allow(deprecated)]
|
||||
pub use self::{
|
||||
polling::{polling, polling_default, Polling, PollingBuilder, PollingStream},
|
||||
stateful_listener::StatefulListener,
|
||||
};
|
||||
pub use self::polling::{polling, polling_default, Polling, PollingBuilder, PollingStream};
|
||||
|
||||
/// An update listener.
|
||||
///
|
||||
/// Implementors of this trait allow getting updates from Telegram. See
|
||||
/// [module-level documentation] for more.
|
||||
///
|
||||
/// Some functions of this trait are located in the supertrait
|
||||
/// ([`AsUpdateStream`]), see also:
|
||||
/// - [`AsUpdateStream::Stream`]
|
||||
/// - [`AsUpdateStream::as_stream`]
|
||||
///
|
||||
/// [module-level documentation]: mod@self
|
||||
pub trait UpdateListener:
|
||||
for<'a> AsUpdateStream<'a, StreamErr = <Self as UpdateListener>::Err>
|
||||
{
|
||||
/// The type of errors that can be returned from this listener.
|
||||
type Err;
|
||||
pub trait UpdateListener {
|
||||
type SetupErr;
|
||||
|
||||
/// Error that can be returned from the [`Stream`]
|
||||
///
|
||||
/// [`Stream`]: UpdateListener::Stream
|
||||
type StreamErr;
|
||||
|
||||
/// The stream of updates from Telegram.
|
||||
// NB: `Send` is not strictly required here, but it makes it easier to return
|
||||
// `impl AsUpdateStream` and also you want `Send` streams almost (?) always
|
||||
// anyway.
|
||||
type Stream<'a>: Stream<Item = Result<Update, Self::StreamErr>> + Send + 'a
|
||||
where
|
||||
Self: 'a;
|
||||
|
||||
/// Creates the update [`Stream`].
|
||||
///
|
||||
/// This function should also do all the necessary setup, and return an
|
||||
/// error if something goes wrong with it. For example for webhooks this
|
||||
/// should call `set_webhook`.
|
||||
///
|
||||
/// [`Stream`]: AsUpdateStream::Stream
|
||||
fn listen(
|
||||
&mut self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>>;
|
||||
|
||||
/// Hint which updates should the listener listen for.
|
||||
///
|
||||
/// For example [`polling()`] should send the hint as
|
||||
/// [`GetUpdates::allowed_updates`]
|
||||
///
|
||||
/// Note: this is a very important method, without setting appropriate
|
||||
/// allowed updates, telegram will not send some update kinds.
|
||||
///
|
||||
/// [`GetUpdates::allowed_updates`]:
|
||||
/// crate::payloads::GetUpdates::allowed_updates
|
||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>);
|
||||
|
||||
/// Returns a token which stops this listener.
|
||||
///
|
||||
|
@ -77,44 +103,6 @@ pub trait UpdateListener:
|
|||
#[must_use = "This function doesn't stop listening, to stop listening you need to call `stop` \
|
||||
on the returned token"]
|
||||
fn stop_token(&mut self) -> StopToken;
|
||||
|
||||
/// Hint which updates should the listener listen for.
|
||||
///
|
||||
/// For example [`polling()`] should send the hint as
|
||||
/// [`GetUpdates::allowed_updates`]
|
||||
///
|
||||
/// Note however that this is a _hint_ and as such, it can be ignored. The
|
||||
/// listener is not guaranteed to only return updates which types are listed
|
||||
/// in the hint.
|
||||
///
|
||||
/// [`GetUpdates::allowed_updates`]:
|
||||
/// crate::payloads::GetUpdates::allowed_updates
|
||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
||||
let _ = hint;
|
||||
}
|
||||
}
|
||||
|
||||
/// [`UpdateListener`]'s supertrait/extension.
|
||||
///
|
||||
/// This trait is a workaround to not require GAT.
|
||||
pub trait AsUpdateStream<'a> {
|
||||
/// Error that can be returned from the [`Stream`]
|
||||
///
|
||||
/// [`Stream`]: AsUpdateStream::Stream
|
||||
// NB: This should be named differently to `UpdateListener::Err`, so that it's
|
||||
// unambiguous
|
||||
type StreamErr;
|
||||
|
||||
/// The stream of updates from Telegram.
|
||||
// NB: `Send` is not strictly required here, but it makes it easier to return
|
||||
// `impl AsUpdateStream` and also you want `Send` streams almost (?) always
|
||||
// anyway.
|
||||
type Stream: Stream<Item = Result<Update, Self::StreamErr>> + Send + 'a;
|
||||
|
||||
/// Creates the update [`Stream`].
|
||||
///
|
||||
/// [`Stream`]: AsUpdateStream::Stream
|
||||
fn as_stream(&'a mut self) -> Self::Stream;
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
|
|
|
@ -17,7 +17,7 @@ use crate::{
|
|||
requests::{HasPayload, Request, Requester},
|
||||
stop::{mk_stop_token, StopFlag, StopToken},
|
||||
types::{AllowedUpdate, Update},
|
||||
update_listeners::{assert_update_listener, AsUpdateStream, UpdateListener},
|
||||
update_listeners::{assert_update_listener, UpdateListener},
|
||||
};
|
||||
|
||||
/// Builder for polling update listener.
|
||||
|
@ -243,7 +243,7 @@ where
|
|||
/// [get_updates]: crate::requests::Requester::get_updates
|
||||
/// [`Dispatcher`]: crate::dispatching::Dispatcher
|
||||
#[must_use = "`Polling` is an update listener and does nothing unless used"]
|
||||
pub struct Polling<B: Requester> {
|
||||
pub struct Polling<B> {
|
||||
bot: B,
|
||||
timeout: Option<Duration>,
|
||||
limit: Option<u8>,
|
||||
|
@ -320,7 +320,53 @@ pub struct PollingStream<'a, B: Requester> {
|
|||
}
|
||||
|
||||
impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
|
||||
type Err = B::Err;
|
||||
type SetupErr = B::Err;
|
||||
type StreamErr = B::Err;
|
||||
type Stream<'a> = PollingStream<'a, B>;
|
||||
|
||||
fn listen(
|
||||
&mut self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>> {
|
||||
Box::pin(async {
|
||||
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
|
||||
let allowed_updates = self.allowed_updates.clone();
|
||||
let drop_pending_updates = self.drop_pending_updates;
|
||||
|
||||
let token_used_and_updated = self.reinit_stop_flag_if_needed();
|
||||
|
||||
// FIXME: document that `listen` is a destructive operation, actually,
|
||||
// and you need to call `stop_token` *again* after it
|
||||
//
|
||||
// maybe also remove the panic, it's a lot of additional work, for little
|
||||
// benefit, it seems like.
|
||||
if token_used_and_updated {
|
||||
panic!(
|
||||
"detected calling `as_stream` a second time after calling `stop_token`. \
|
||||
`as_stream` updates the stop token, thus you need to call it again after \
|
||||
calling `as_stream`"
|
||||
)
|
||||
}
|
||||
|
||||
// FIXME: do update dropping *here*
|
||||
|
||||
// Unwrap: just called reinit
|
||||
let flag = self.flag.take().unwrap();
|
||||
let stream = PollingStream {
|
||||
polling: self,
|
||||
drop_pending_updates,
|
||||
timeout,
|
||||
allowed_updates,
|
||||
offset: 0,
|
||||
force_stop: false,
|
||||
stopping: false,
|
||||
buffer: Vec::new().into_iter(),
|
||||
in_flight: None,
|
||||
flag,
|
||||
};
|
||||
|
||||
Ok(stream)
|
||||
})
|
||||
}
|
||||
|
||||
fn stop_token(&mut self) -> StopToken {
|
||||
self.reinit_stop_flag_if_needed();
|
||||
|
@ -335,44 +381,6 @@ impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
|
||||
type StreamErr = B::Err;
|
||||
type Stream = PollingStream<'a, B>;
|
||||
|
||||
fn as_stream(&'a mut self) -> Self::Stream {
|
||||
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
|
||||
let allowed_updates = self.allowed_updates.clone();
|
||||
let drop_pending_updates = self.drop_pending_updates;
|
||||
|
||||
let token_used_and_updated = self.reinit_stop_flag_if_needed();
|
||||
|
||||
// FIXME: document that `as_stream` is a destructive operation, actually,
|
||||
// and you need to call `stop_token` *again* after it
|
||||
if token_used_and_updated {
|
||||
panic!(
|
||||
"detected calling `as_stream` a second time after calling `stop_token`. \
|
||||
`as_stream` updates the stop token, thus you need to call it again after calling \
|
||||
`as_stream`"
|
||||
)
|
||||
}
|
||||
|
||||
// Unwrap: just called reinit
|
||||
let flag = self.flag.take().unwrap();
|
||||
PollingStream {
|
||||
polling: self,
|
||||
drop_pending_updates,
|
||||
timeout,
|
||||
allowed_updates,
|
||||
offset: 0,
|
||||
force_stop: false,
|
||||
stopping: false,
|
||||
buffer: Vec::new().into_iter(),
|
||||
in_flight: None,
|
||||
flag,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Requester> Stream for PollingStream<'_, B> {
|
||||
type Item = Result<Update, B::Err>;
|
||||
|
||||
|
@ -471,8 +479,12 @@ fn polling_is_send() {
|
|||
let mut polling = polling(bot, None, None, None);
|
||||
|
||||
assert_send(&polling);
|
||||
assert_send(&polling.as_stream());
|
||||
assert_send(&polling.listen());
|
||||
assert_send(&polling.stop_token());
|
||||
|
||||
_ = async {
|
||||
assert_send(&polling.listen().await.unwrap());
|
||||
};
|
||||
|
||||
fn assert_send(_: &impl Send) {}
|
||||
}
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
use futures::Stream;
|
||||
|
||||
use crate::{
|
||||
stop::StopToken,
|
||||
types::{AllowedUpdate, Update},
|
||||
update_listeners::{AsUpdateStream, UpdateListener},
|
||||
};
|
||||
|
||||
/// A listener created from functions.
|
||||
///
|
||||
/// This type allows to turn a stream of updates (+ some additional functions)
|
||||
/// into an [`UpdateListener`].
|
||||
///
|
||||
/// For an example of usage, see [`polling`].
|
||||
///
|
||||
/// [`polling`]: crate::update_listeners::polling()
|
||||
#[non_exhaustive]
|
||||
pub struct StatefulListener<St, Assf, Sf, Hauf> {
|
||||
/// The state of the listener.
|
||||
pub state: St,
|
||||
|
||||
/// The function used as [`AsUpdateStream::as_stream`].
|
||||
///
|
||||
/// Must implement `for<'a> FnMut(&'a mut St) -> impl Stream + 'a`.
|
||||
pub stream: Assf,
|
||||
|
||||
/// The function used as [`UpdateListener::stop_token`].
|
||||
///
|
||||
/// Must implement `FnMut(&mut St) -> StopToken`.
|
||||
pub stop_token: Sf,
|
||||
|
||||
/// The function used as [`UpdateListener::hint_allowed_updates`].
|
||||
///
|
||||
/// Must implement `FnMut(&mut St, &mut dyn Iterator<Item =
|
||||
/// AllowedUpdate>)`.
|
||||
pub hint_allowed_updates: Option<Hauf>,
|
||||
}
|
||||
|
||||
type Haufn<State> = for<'a, 'b> fn(&'a mut State, &'b mut dyn Iterator<Item = AllowedUpdate>);
|
||||
|
||||
impl<St, Assf, Sf> StatefulListener<St, Assf, Sf, Haufn<St>> {
|
||||
/// Creates a new stateful listener from its components.
|
||||
pub fn new(state: St, stream: Assf, stop_token: Sf) -> Self {
|
||||
Self::new_with_hints(state, stream, stop_token, None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, Assf, Sf, Hauf> StatefulListener<St, Assf, Sf, Hauf> {
|
||||
/// Creates a new stateful listener from its components.
|
||||
pub fn new_with_hints(
|
||||
state: St,
|
||||
stream: Assf,
|
||||
stop_token: Sf,
|
||||
hint_allowed_updates: Option<Hauf>,
|
||||
) -> Self {
|
||||
Self { state, stream, stop_token, hint_allowed_updates }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, St, Assf, Sf, Hauf, Strm, E> AsUpdateStream<'a> for StatefulListener<St, Assf, Hauf, Sf>
|
||||
where
|
||||
(St, Strm): 'a,
|
||||
Strm: Send,
|
||||
Assf: FnMut(&'a mut St) -> Strm,
|
||||
Strm: Stream<Item = Result<Update, E>>,
|
||||
{
|
||||
type StreamErr = E;
|
||||
type Stream = Strm;
|
||||
|
||||
fn as_stream(&'a mut self) -> Self::Stream {
|
||||
(self.stream)(&mut self.state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<St, Assf, Sf, Hauf, E> UpdateListener for StatefulListener<St, Assf, Sf, Hauf>
|
||||
where
|
||||
Self: for<'a> AsUpdateStream<'a, StreamErr = E>,
|
||||
Sf: FnMut(&mut St) -> StopToken,
|
||||
Hauf: FnMut(&mut St, &mut dyn Iterator<Item = AllowedUpdate>),
|
||||
{
|
||||
type Err = E;
|
||||
|
||||
fn stop_token(&mut self) -> StopToken {
|
||||
(self.stop_token)(&mut self.state)
|
||||
}
|
||||
|
||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
||||
if let Some(f) = &mut self.hint_allowed_updates {
|
||||
f(&mut self.state, hint);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
//!
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::{requests::Requester, types::InputFile};
|
||||
use crate::{requests::Requester, types::InputFile, update_listeners::AllowedUpdate};
|
||||
|
||||
/// Options related to setting up webhooks.
|
||||
#[must_use]
|
||||
|
@ -116,7 +116,7 @@ impl Options {
|
|||
}
|
||||
|
||||
#[cfg(feature = "webhooks-axum")]
|
||||
pub use self::axum::{axum, axum_no_setup, axum_to_router};
|
||||
pub use self::axum::axum;
|
||||
|
||||
#[cfg(feature = "webhooks-axum")]
|
||||
mod axum;
|
||||
|
@ -124,25 +124,26 @@ mod axum;
|
|||
// TODO: add different implementation (for example: warp)
|
||||
|
||||
/// Calls `set_webhook` with arguments from `options`.
|
||||
///
|
||||
/// Note: this takes out `certificate`.
|
||||
async fn setup_webhook<R>(bot: R, options: &mut Options) -> Result<(), R::Err>
|
||||
async fn setup_webhook<R>(
|
||||
bot: R,
|
||||
options: &Options,
|
||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||
) -> Result<(), R::Err>
|
||||
where
|
||||
R: Requester,
|
||||
{
|
||||
use crate::requests::Request;
|
||||
use teloxide_core::requests::HasPayload;
|
||||
|
||||
let secret = options.get_or_gen_secret_token().to_owned();
|
||||
let &mut Options {
|
||||
ref url, ref mut certificate, max_connections, drop_pending_updates, ..
|
||||
} = options;
|
||||
let Options { url, certificate, secret_token, max_connections, drop_pending_updates, .. } =
|
||||
options;
|
||||
|
||||
let mut req = bot.set_webhook(url.clone());
|
||||
req.payload_mut().certificate = certificate.take();
|
||||
req.payload_mut().max_connections = max_connections;
|
||||
req.payload_mut().drop_pending_updates = Some(drop_pending_updates);
|
||||
req.payload_mut().secret_token = Some(secret);
|
||||
req.payload_mut().certificate = certificate.clone();
|
||||
req.payload_mut().max_connections = *max_connections;
|
||||
req.payload_mut().drop_pending_updates = Some(*drop_pending_updates);
|
||||
req.payload_mut().secret_token = secret_token.clone();
|
||||
req.payload_mut().allowed_updates = allowed_updates;
|
||||
|
||||
req.send().await?;
|
||||
|
||||
|
@ -182,17 +183,3 @@ fn check_secret(bytes: &[u8]) -> Result<&[u8], &'static str> {
|
|||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Returns first (`.0`) field from a tuple as a `&mut` reference.
|
||||
///
|
||||
/// This hack is needed because there isn't currently a way to easily force a
|
||||
/// closure to be higher-ranked (`for<'a> &'a mut _ -> &'a mut _`) which causes
|
||||
/// problems when using [`StatefulListener`] to implement update listener.
|
||||
///
|
||||
/// This could be probably removed once [rfc#3216] is implemented.
|
||||
///
|
||||
/// [`StatefulListener`]:
|
||||
/// [rfc#3216]: https://github.com/rust-lang/rfcs/pull/3216
|
||||
fn tuple_first_mut<A, B>(tuple: &mut (A, B)) -> &mut A {
|
||||
&mut tuple.0
|
||||
}
|
||||
|
|
|
@ -1,18 +1,66 @@
|
|||
use std::{convert::Infallible, future::Future, pin::Pin};
|
||||
use crate::{
|
||||
requests::{Request, Requester},
|
||||
stop::{mk_stop_token, StopFlag},
|
||||
types::{AllowedUpdate, True, Update, UpdateKind},
|
||||
update_listeners::{
|
||||
webhooks::{setup_webhook, Options},
|
||||
StopToken, UpdateListener,
|
||||
},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{FromRequestParts, State},
|
||||
http::{request::Parts, status::StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::post,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use crate::{
|
||||
requests::Requester,
|
||||
stop::StopFlag,
|
||||
types::{Update, UpdateKind},
|
||||
update_listeners::{webhooks::Options, UpdateListener},
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
error::Error,
|
||||
fmt::{self, Debug, Display},
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task,
|
||||
task::Poll,
|
||||
};
|
||||
|
||||
/// A webhook update listener backed by [`axum`](mod@axum).
|
||||
pub struct Axum<B> {
|
||||
bot: B,
|
||||
options: Options,
|
||||
token: StopToken,
|
||||
flag: Option<StopFlag>,
|
||||
allowed_updates: Option<Vec<AllowedUpdate>>,
|
||||
/// This is a stream of updates, coming from an `axum::Router` we've
|
||||
/// created.
|
||||
///
|
||||
/// N.B. This field is only initialized by `take_router` and is only
|
||||
/// de-initialized by `listen`. Basically, it's a way to pass the
|
||||
/// channel from the router creation, to the listener stream.
|
||||
stream: Option<UnboundedReceiverStream<Update>>,
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct AxumStream<'a, B: Requester> {
|
||||
axum: &'a mut Axum<B>,
|
||||
|
||||
#[pin]
|
||||
inner: UnboundedReceiverStream<Update>,
|
||||
|
||||
#[pin]
|
||||
webhook_deletion: Option<Option<<B::DeleteWebhook as Request>::Send>>,
|
||||
}
|
||||
|
||||
pub enum SetupError<B: Requester> {
|
||||
Bind(hyper::Error),
|
||||
SetWebhook(B::Err),
|
||||
}
|
||||
|
||||
/// Webhook implementation based on the [mod@axum] framework.
|
||||
///
|
||||
/// This function does all the work necessary for webhook to work, it:
|
||||
|
@ -23,38 +71,71 @@ use crate::{
|
|||
/// [`set_webhook`]: crate::payloads::SetWebhook
|
||||
/// [`delete_webhook`]: crate::payloads::DeleteWebhook
|
||||
/// [`stop`]: crate::stop::StopToken::stop
|
||||
///
|
||||
/// ## Panics
|
||||
///
|
||||
/// If binding to the [address] fails.
|
||||
///
|
||||
/// [address]: Options::address
|
||||
///
|
||||
/// ## Fails
|
||||
///
|
||||
/// If `set_webhook()` fails.
|
||||
///
|
||||
/// ## See also
|
||||
///
|
||||
/// [`axum_to_router`] and [`axum_no_setup`] for lower-level versions of this
|
||||
/// function.
|
||||
pub async fn axum<R>(
|
||||
bot: R,
|
||||
options: Options,
|
||||
) -> Result<impl UpdateListener<Err = Infallible>, R::Err>
|
||||
pub fn axum<R>(bot: R, mut options: Options) -> Axum<R>
|
||||
where
|
||||
R: Requester + Send + 'static,
|
||||
<R as Requester>::DeleteWebhook: Send,
|
||||
R: Requester + Sync + Send + Clone + 'static,
|
||||
R::SetWebhook: Send,
|
||||
R::DeleteWebhook: Send,
|
||||
{
|
||||
let Options { address, .. } = options;
|
||||
_ = options.get_or_gen_secret_token();
|
||||
let (token, flag) = mk_stop_token();
|
||||
|
||||
let (mut update_listener, stop_flag, app) = axum_to_router(bot, options).await?;
|
||||
let stop_token = update_listener.stop_token();
|
||||
Axum { bot, options, token, flag: Some(flag), allowed_updates: None, stream: None }
|
||||
}
|
||||
|
||||
impl<B> Axum<B> {
|
||||
/// Returns a router that will listen to updates.
|
||||
///
|
||||
/// N.B. you need to get a new router each time you re-start dispatching.
|
||||
pub fn take_router(&mut self) -> Option<axum::Router> {
|
||||
match self.stream {
|
||||
None => {
|
||||
self.reinit_stop_flag_if_needed();
|
||||
let stop_flag = self.flag.as_ref().unwrap().clone();
|
||||
let (router, stream) = create_router(&self.options, stop_flag);
|
||||
self.stream = Some(stream);
|
||||
Some(router)
|
||||
}
|
||||
Some(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn reinit_stop_flag_if_needed(&mut self) {
|
||||
if self.flag.is_none() {
|
||||
let (token, flag) = mk_stop_token();
|
||||
self.token = token;
|
||||
self.flag = Some(flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> UpdateListener for Axum<B>
|
||||
where
|
||||
B: Requester + Sync + Send + 'static,
|
||||
B::SetWebhook: Send,
|
||||
{
|
||||
type SetupErr = SetupError<B>;
|
||||
type StreamErr = Infallible;
|
||||
type Stream<'a> = AxumStream<'a, B>;
|
||||
|
||||
fn listen(
|
||||
&mut self,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>> {
|
||||
Box::pin(async {
|
||||
// Unwrap: `stop_flag` always returns `Some`.
|
||||
self.reinit_stop_flag_if_needed();
|
||||
let stop_flag = self.flag.take().unwrap();
|
||||
let stop_token = self.token.clone();
|
||||
|
||||
// If the user did not take the router themselves — spawn an axum server
|
||||
if let Some(router) = self.take_router() {
|
||||
let server = axum::Server::try_bind(&self.options.address)
|
||||
.map_err(SetupError::Bind)?
|
||||
.serve(router.into_make_service())
|
||||
.with_graceful_shutdown(stop_flag);
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::Server::bind(&address)
|
||||
.serve(app.into_make_service())
|
||||
.with_graceful_shutdown(stop_flag)
|
||||
server
|
||||
.await
|
||||
.map_err(|err| {
|
||||
stop_token.stop();
|
||||
|
@ -62,114 +143,139 @@ where
|
|||
})
|
||||
.expect("Axum server error");
|
||||
});
|
||||
}
|
||||
|
||||
Ok(update_listener)
|
||||
// Unwrap: just called `take_router`
|
||||
let stream = self.stream.take().unwrap();
|
||||
|
||||
setup_webhook(&self.bot, &self.options, self.allowed_updates.clone())
|
||||
.await
|
||||
.map_err(SetupError::SetWebhook)?;
|
||||
|
||||
let stream = AxumStream { axum: self, inner: stream, webhook_deletion: None };
|
||||
|
||||
Ok(stream)
|
||||
})
|
||||
}
|
||||
|
||||
fn stop_token(&mut self) -> StopToken {
|
||||
self.reinit_stop_flag_if_needed();
|
||||
self.token.clone()
|
||||
}
|
||||
|
||||
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
|
||||
// TODO: we should probably warn if there already were different allowed updates
|
||||
// before
|
||||
self.allowed_updates = Some(hint.collect());
|
||||
}
|
||||
}
|
||||
|
||||
/// Webhook implementation based on the [mod@axum] framework that can reuse
|
||||
/// existing [mod@axum] server.
|
||||
///
|
||||
/// This function does most of the work necessary for webhook to work, it:
|
||||
/// - Calls [`set_webhook`], so telegram starts sending updates our way
|
||||
/// - When the update listener is [`stop`]ped, calls [`delete_webhook`]
|
||||
///
|
||||
/// The only missing part is running [mod@axum] server with a returned
|
||||
/// [`axum::Router`].
|
||||
///
|
||||
/// This function is intended to be used in cases when you already have an
|
||||
/// [mod@axum] server running and can reuse it for webhooks.
|
||||
///
|
||||
/// **Note**: in order for webhooks to work, you need to use returned
|
||||
/// [`axum::Router`] in an [mod@axum] server that is bound to
|
||||
/// [`options.address`].
|
||||
///
|
||||
/// It may also be desired to use [`with_graceful_shutdown`] with the returned
|
||||
/// future in order to shutdown the server with the [`stop`] of the listener.
|
||||
///
|
||||
/// [`set_webhook`]: crate::payloads::SetWebhook
|
||||
/// [`delete_webhook`]: crate::payloads::DeleteWebhook
|
||||
/// [`stop`]: crate::stop::StopToken::stop
|
||||
/// [`options.address`]: Options::address
|
||||
/// [`with_graceful_shutdown`]: axum::Server::with_graceful_shutdown
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// A update listener, stop-future, axum router triplet on success.
|
||||
///
|
||||
/// The "stop-future" is resolved after [`stop`] is called on the stop token of
|
||||
/// the returned update listener.
|
||||
///
|
||||
/// ## Fails
|
||||
///
|
||||
/// If `set_webhook()` fails.
|
||||
///
|
||||
/// ## See also
|
||||
///
|
||||
/// [`fn@axum`] for higher-level and [`axum_no_setup`] for lower-level
|
||||
/// versions of this function.
|
||||
pub async fn axum_to_router<R>(
|
||||
bot: R,
|
||||
mut options: Options,
|
||||
) -> Result<
|
||||
(impl UpdateListener<Err = Infallible>, impl Future<Output = ()> + Send, axum::Router),
|
||||
R::Err,
|
||||
>
|
||||
impl<B> Stream for AxumStream<'_, B>
|
||||
where
|
||||
R: Requester + Send,
|
||||
<R as Requester>::DeleteWebhook: Send,
|
||||
B: Requester,
|
||||
{
|
||||
use crate::{requests::Request, update_listeners::webhooks::setup_webhook};
|
||||
use futures::FutureExt;
|
||||
type Item = Result<Update, Infallible>;
|
||||
|
||||
setup_webhook(&bot, &mut options).await?;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let mut this = self.project();
|
||||
|
||||
let (listener, stop_flag, router) = axum_no_setup(options);
|
||||
match this.inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Some(upd)) => Poll::Ready(Some(Ok(upd))),
|
||||
Poll::Ready(None) => {
|
||||
if let Some(mut deletion) = this.webhook_deletion.as_mut().as_pin_mut() {
|
||||
if let Some(future) = deletion.as_mut().as_pin_mut() {
|
||||
// `Some(Some(_))` — we are currently deleting webhook
|
||||
match future.poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
||||
let stop_flag = stop_flag.then(move |()| async move {
|
||||
// This assignment is needed to not require `R: Sync` since without it `&bot`
|
||||
// temporary lives across `.await` points.
|
||||
let req = bot.delete_webhook().send();
|
||||
let res = req.await;
|
||||
if let Err(err) = res {
|
||||
log::error!("Couldn't delete webhook: {}", err);
|
||||
// We completed webhook deletion (potentially failed)
|
||||
Poll::Ready(Ok(True) | Err(_)) => {
|
||||
this.webhook_deletion.set(Some(None));
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// `Some(None)` — we've already deleted webhook
|
||||
Poll::Ready(None)
|
||||
}
|
||||
} else {
|
||||
// `None` — we haven't yet started deleting webhook
|
||||
|
||||
this.webhook_deletion.set(Some(Some(this.axum.bot.delete_webhook().send())));
|
||||
|
||||
// Immediately wake up to poll `self.in_flight`
|
||||
// (without this this stream becomes a zombie)
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Requester> Debug for SetupError<R>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Bind(e) => f.debug_tuple("Bind").field(&e).finish(),
|
||||
Self::SetWebhook(e) => f.debug_tuple("SetWebhook").field(&e).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> fmt::Display for SetupError<R>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Bind(e) => write!(f, "Error while binding an address for webhooks: {e}"),
|
||||
Self::SetWebhook(e) => write!(f, "Error while setting up webhooks: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Error for SetupError<R>
|
||||
where
|
||||
R: Requester,
|
||||
R::Err: Error + Debug + Display + 'static,
|
||||
{
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::Bind(e) => Some(e),
|
||||
Self::SetWebhook(e) => Some(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_router(
|
||||
options: &Options,
|
||||
stop_flag: StopFlag,
|
||||
) -> (axum::Router, UnboundedReceiverStream<Update>) {
|
||||
let (tx, rx): (mpsc::UnboundedSender<Update>, _) = mpsc::unbounded_channel();
|
||||
|
||||
let app = axum::Router::new()
|
||||
.route(options.url.path(), post(telegram_request))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(WebhookState {
|
||||
tx: ClosableSender::new(tx),
|
||||
flag: stop_flag,
|
||||
secret: options.secret_token.clone(),
|
||||
});
|
||||
|
||||
Ok((listener, stop_flag, router))
|
||||
(app, rx.into())
|
||||
}
|
||||
|
||||
/// Webhook implementation based on the [mod@axum] framework that doesn't
|
||||
/// perform any setup work.
|
||||
///
|
||||
/// ## Note about the stop-future
|
||||
///
|
||||
/// This function returns a future that is resolved when `.stop()` is called on
|
||||
/// a stop token of the update listener. Note that even if the future is not
|
||||
/// used, after `.stop()` is called, update listener will not produce new
|
||||
/// updates.
|
||||
///
|
||||
/// ## See also
|
||||
///
|
||||
/// [`fn@axum`] and [`axum_to_router`] for higher-level versions of this
|
||||
/// function.
|
||||
pub fn axum_no_setup(
|
||||
options: Options,
|
||||
) -> (impl UpdateListener<Err = Infallible>, impl Future<Output = ()>, axum::Router) {
|
||||
use crate::{
|
||||
stop::{mk_stop_token, StopToken},
|
||||
update_listeners::{webhooks::tuple_first_mut, StatefulListener},
|
||||
};
|
||||
use axum::{response::IntoResponse, routing::post};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel();
|
||||
|
||||
async fn telegram_request(
|
||||
async fn telegram_request(
|
||||
State(WebhookState { secret, flag, mut tx }): State<WebhookState>,
|
||||
secret_header: XTelegramBotApiSecretToken,
|
||||
input: String,
|
||||
) -> impl IntoResponse {
|
||||
) -> impl IntoResponse {
|
||||
// FIXME: use constant time comparison here
|
||||
if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) {
|
||||
return StatusCode::UNAUTHORIZED;
|
||||
|
@ -194,7 +300,7 @@ pub fn axum_no_setup(
|
|||
*value = serde_json::from_str(&input).unwrap_or_default();
|
||||
}
|
||||
|
||||
tx.send(Ok(update)).expect("Cannot send an incoming update from the webhook")
|
||||
tx.send(update).expect("Cannot send an incoming update from the webhook")
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!(
|
||||
|
@ -208,37 +314,11 @@ pub fn axum_no_setup(
|
|||
};
|
||||
|
||||
StatusCode::OK
|
||||
}
|
||||
|
||||
let (stop_token, stop_flag) = mk_stop_token();
|
||||
|
||||
let app = axum::Router::new()
|
||||
.route(options.url.path(), post(telegram_request))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(WebhookState {
|
||||
tx: ClosableSender::new(tx),
|
||||
flag: stop_flag.clone(),
|
||||
secret: options.secret_token,
|
||||
});
|
||||
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
|
||||
// FIXME: this should support `hint_allowed_updates()`
|
||||
let listener = StatefulListener::new(
|
||||
(stream, stop_token),
|
||||
tuple_first_mut,
|
||||
|state: &mut (_, StopToken)| state.1.clone(),
|
||||
);
|
||||
|
||||
(listener, stop_flag, app)
|
||||
}
|
||||
|
||||
type UpdateSender = mpsc::UnboundedSender<Result<Update, std::convert::Infallible>>;
|
||||
type UpdateCSender = ClosableSender<Result<Update, std::convert::Infallible>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct WebhookState {
|
||||
tx: UpdateCSender,
|
||||
tx: ClosableSender<Update>,
|
||||
flag: StopFlag,
|
||||
secret: Option<String>,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue