Refactor UpdateListener to allow properly setting allowed updates for webhooks

This commit is contained in:
Maybe Waffle 2024-01-11 01:12:55 +01:00
parent 0d47b40137
commit 316ca97886
11 changed files with 469 additions and 437 deletions

View file

@ -21,7 +21,7 @@ categories = ["web-programming", "api-bindings", "asynchronous"]
default = ["native-tls", "ctrlc_handler", "teloxide-core/default", "auto-send"] default = ["native-tls", "ctrlc_handler", "teloxide-core/default", "auto-send"]
webhooks = ["rand"] webhooks = ["rand"]
webhooks-axum = ["webhooks", "axum", "tower", "tower-http"] webhooks-axum = ["webhooks", "axum", "tower", "tower-http", "hyper"]
# FIXME: rename `sqlite-storage` -> `sqlite-storage-nativetls` # FIXME: rename `sqlite-storage` -> `sqlite-storage-nativetls`
sqlite-storage = ["sqlx", "sqlx/runtime-tokio-native-tls", "native-tls"] sqlite-storage = ["sqlx", "sqlx/runtime-tokio-native-tls", "native-tls"]
@ -109,6 +109,7 @@ bincode = { version = "1.3", optional = true }
axum = { version = "0.6.0", optional = true } axum = { version = "0.6.0", optional = true }
tower = { version = "0.4.12", optional = true } tower = { version = "0.4.12", optional = true }
tower-http = { version = "0.3.4", features = ["trace"], optional = true } tower-http = { version = "0.3.4", features = ["trace"], optional = true }
hyper = { version = "0.14", optional = true, default-features = false }
rand = { version = "0.8.5", optional = true } rand = { version = "0.8.5", optional = true }

View file

@ -41,9 +41,7 @@ async fn main() {
let host = env::var("HOST").expect("HOST env variable is not set"); let host = env::var("HOST").expect("HOST env variable is not set");
let url = format!("https://{host}/webhook").parse().unwrap(); let url = format!("https://{host}/webhook").parse().unwrap();
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)) let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url));
.await
.expect("Couldn't setup webhook");
teloxide::repl_with_listener( teloxide::repl_with_listener(
bot, bot,

View file

@ -12,9 +12,7 @@ async fn main() {
let addr = ([127, 0, 0, 1], 8443).into(); let addr = ([127, 0, 0, 1], 8443).into();
let url = "Your HTTPS ngrok URL here. Get it by `ngrok http 8443`".parse().unwrap(); let url = "Your HTTPS ngrok URL here. Get it by `ngrok http 8443`".parse().unwrap();
let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)) let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url));
.await
.expect("Couldn't setup webhook");
teloxide::repl_with_listener( teloxide::repl_with_listener(
bot, bot,

View file

@ -20,7 +20,8 @@ use tokio_stream::wrappers::ReceiverStream;
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt::Debug, error::Error,
fmt::{self, Debug, Display},
future::Future, future::Future,
hash::Hash, hash::Hash,
ops::{ControlFlow, Deref}, ops::{ControlFlow, Deref},
@ -214,6 +215,15 @@ pub struct Dispatcher<R, Err, Key> {
state: ShutdownToken, state: ShutdownToken,
} }
/// An error returned from [`Disatcher::try_dispatch_with_listener`].
pub enum TryDispatchError<R: Requester, L: UpdateListener> {
/// An error from calling `get_me` while creating dispatcher context.
GetMe(R::Err),
/// An error during update listener setup.
ListenerSetup(L::SetupErr),
}
struct Worker { struct Worker {
tx: tokio::sync::mpsc::Sender<Update>, tx: tokio::sync::mpsc::Sender<Update>,
handle: tokio::task::JoinHandle<()>, handle: tokio::task::JoinHandle<()>,
@ -300,8 +310,8 @@ where
update_listener_error_handler: Arc<Eh>, update_listener_error_handler: Arc<Eh>,
) where ) where
UListener: UpdateListener + 'a, UListener: UpdateListener + 'a,
Eh: ErrorHandler<UListener::Err> + 'a, Eh: ErrorHandler<UListener::StreamErr> + 'a,
UListener::Err: Debug, UListener::SetupErr: Debug,
{ {
self.try_dispatch_with_listener(update_listener, update_listener_error_handler) self.try_dispatch_with_listener(update_listener, update_listener_error_handler)
.await .await
@ -319,14 +329,13 @@ where
&'a mut self, &'a mut self,
mut update_listener: UListener, mut update_listener: UListener,
update_listener_error_handler: Arc<Eh>, update_listener_error_handler: Arc<Eh>,
) -> Result<(), R::Err> ) -> Result<(), TryDispatchError<R, UListener>>
where where
UListener: UpdateListener + 'a, UListener: UpdateListener + 'a,
Eh: ErrorHandler<UListener::Err> + 'a, Eh: ErrorHandler<UListener::StreamErr> + 'a,
UListener::Err: Debug,
{ {
// FIXME: there should be a way to check if dependency is already inserted // FIXME: there should be a way to check if dependency is already inserted
let me = self.bot.get_me().send().await?; let me = self.bot.get_me().send().await.map_err(TryDispatchError::GetMe)?;
self.dependencies.insert(me); self.dependencies.insert(me);
self.dependencies.insert(self.bot.clone()); self.dependencies.insert(self.bot.clone());
@ -340,7 +349,7 @@ where
self.state.start_dispatching(); self.state.start_dispatching();
{ {
let stream = update_listener.as_stream(); let stream = update_listener.listen().await.map_err(TryDispatchError::ListenerSetup)?;
tokio::pin!(stream); tokio::pin!(stream);
loop { loop {
@ -521,6 +530,51 @@ impl<R, Err, Key> Dispatcher<R, Err, Key> {
} }
} }
impl<R: Requester, L: UpdateListener> Debug for TryDispatchError<R, L>
where
R: Requester,
R::Err: Debug,
L: UpdateListener,
L::SetupErr: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::GetMe(e) => f.debug_tuple("GetMe").field(&e).finish(),
Self::ListenerSetup(e) => f.debug_tuple("ListenerSetup").field(&e).finish(),
}
}
}
impl<R, L> fmt::Display for TryDispatchError<R, L>
where
R: Requester,
R::Err: Display,
L: UpdateListener,
L::SetupErr: Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::GetMe(e) => write!(f, "Error while setting up update listener: {e}"),
Self::ListenerSetup(e) => write!(f, "Error while setting up update listener: {e}"),
}
}
}
impl<R, L> Error for TryDispatchError<R, L>
where
R: Requester,
R::Err: Error + Debug + Display + 'static,
L: UpdateListener,
L::SetupErr: Error + Debug + Display + 'static,
{
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::GetMe(e) => Some(e),
Self::ListenerSetup(e) => Some(e),
}
}
}
fn spawn_worker<Err>( fn spawn_worker<Err>(
deps: DependencyMap, deps: DependencyMap,
handler: Arc<UpdateHandler<Err>>, handler: Arc<UpdateHandler<Err>>,

View file

@ -85,8 +85,11 @@ pub trait CommandReplExt {
fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()> fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()>
where where
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static, H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
L: UpdateListener + Send + 'a, // FIXME: why do we need + 'static here??? (and in similar cases) (compiler error does not
L::Err: Debug + Send + 'a, // provide enough info...)
L: UpdateListener + Send + 'static,
L::SetupErr: Debug,
L::StreamErr: Debug + Send + 'a,
R: Requester + Clone + Send + Sync + 'static, R: Requester + Clone + Send + Sync + 'static,
<R as Requester>::GetMe: Send; <R as Requester>::GetMe: Send;
} }
@ -120,8 +123,9 @@ where
fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()> fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()>
where where
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static, H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
L: UpdateListener + Send + 'a, L: UpdateListener + Send + 'static,
L::Err: Debug + Send + 'a, L::SetupErr: Debug,
L::StreamErr: Debug + Send + 'a,
R: Requester + Clone + Send + Sync + 'static, R: Requester + Clone + Send + Sync + 'static,
<R as Requester>::GetMe: Send, <R as Requester>::GetMe: Send,
{ {
@ -277,7 +281,8 @@ pub async fn commands_repl_with_listener<'a, R, Cmd, H, L, Args>(
Cmd: BotCommands + Send + Sync + 'static, Cmd: BotCommands + Send + Sync + 'static,
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static, H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
L: UpdateListener + Send + 'a, L: UpdateListener + Send + 'a,
L::Err: Debug + Send + 'a, L::SetupErr: Debug,
L::StreamErr: Debug,
R: Requester + Clone + Send + Sync + 'static, R: Requester + Clone + Send + Sync + 'static,
{ {
use crate::dispatching::Dispatcher; use crate::dispatching::Dispatcher;

View file

@ -108,7 +108,8 @@ where
R: Requester + Clone + Send + Sync + 'static, R: Requester + Clone + Send + Sync + 'static,
H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static, H: Injectable<DependencyMap, ResponseResult<()>, Args> + Send + Sync + 'static,
L: UpdateListener + Send, L: UpdateListener + Send,
L::Err: Debug, L::SetupErr: Debug,
L::StreamErr: Debug,
{ {
use crate::dispatching::Dispatcher; use crate::dispatching::Dispatcher;

View file

@ -30,7 +30,9 @@
#[cfg(feature = "webhooks")] #[cfg(feature = "webhooks")]
pub mod webhooks; pub mod webhooks;
use futures::Stream; use std::pin::Pin;
use futures::{Future, Stream};
use crate::{ use crate::{
stop::StopToken, stop::StopToken,
@ -38,30 +40,54 @@ use crate::{
}; };
mod polling; mod polling;
mod stateful_listener;
#[allow(deprecated)] #[allow(deprecated)]
pub use self::{ pub use self::polling::{polling, polling_default, Polling, PollingBuilder, PollingStream};
polling::{polling, polling_default, Polling, PollingBuilder, PollingStream},
stateful_listener::StatefulListener,
};
/// An update listener. /// An update listener.
/// ///
/// Implementors of this trait allow getting updates from Telegram. See /// Implementors of this trait allow getting updates from Telegram. See
/// [module-level documentation] for more. /// [module-level documentation] for more.
/// ///
/// Some functions of this trait are located in the supertrait
/// ([`AsUpdateStream`]), see also:
/// - [`AsUpdateStream::Stream`]
/// - [`AsUpdateStream::as_stream`]
///
/// [module-level documentation]: mod@self /// [module-level documentation]: mod@self
pub trait UpdateListener: pub trait UpdateListener {
for<'a> AsUpdateStream<'a, StreamErr = <Self as UpdateListener>::Err> type SetupErr;
{
/// The type of errors that can be returned from this listener. /// Error that can be returned from the [`Stream`]
type Err; ///
/// [`Stream`]: UpdateListener::Stream
type StreamErr;
/// The stream of updates from Telegram.
// NB: `Send` is not strictly required here, but it makes it easier to return
// `impl AsUpdateStream` and also you want `Send` streams almost (?) always
// anyway.
type Stream<'a>: Stream<Item = Result<Update, Self::StreamErr>> + Send + 'a
where
Self: 'a;
/// Creates the update [`Stream`].
///
/// This function should also do all the necessary setup, and return an
/// error if something goes wrong with it. For example for webhooks this
/// should call `set_webhook`.
///
/// [`Stream`]: AsUpdateStream::Stream
fn listen(
&mut self,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>>;
/// Hint which updates should the listener listen for.
///
/// For example [`polling()`] should send the hint as
/// [`GetUpdates::allowed_updates`]
///
/// Note: this is a very important method, without setting appropriate
/// allowed updates, telegram will not send some update kinds.
///
/// [`GetUpdates::allowed_updates`]:
/// crate::payloads::GetUpdates::allowed_updates
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>);
/// Returns a token which stops this listener. /// Returns a token which stops this listener.
/// ///
@ -77,44 +103,6 @@ pub trait UpdateListener:
#[must_use = "This function doesn't stop listening, to stop listening you need to call `stop` \ #[must_use = "This function doesn't stop listening, to stop listening you need to call `stop` \
on the returned token"] on the returned token"]
fn stop_token(&mut self) -> StopToken; fn stop_token(&mut self) -> StopToken;
/// Hint which updates should the listener listen for.
///
/// For example [`polling()`] should send the hint as
/// [`GetUpdates::allowed_updates`]
///
/// Note however that this is a _hint_ and as such, it can be ignored. The
/// listener is not guaranteed to only return updates which types are listed
/// in the hint.
///
/// [`GetUpdates::allowed_updates`]:
/// crate::payloads::GetUpdates::allowed_updates
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
let _ = hint;
}
}
/// [`UpdateListener`]'s supertrait/extension.
///
/// This trait is a workaround to not require GAT.
pub trait AsUpdateStream<'a> {
/// Error that can be returned from the [`Stream`]
///
/// [`Stream`]: AsUpdateStream::Stream
// NB: This should be named differently to `UpdateListener::Err`, so that it's
// unambiguous
type StreamErr;
/// The stream of updates from Telegram.
// NB: `Send` is not strictly required here, but it makes it easier to return
// `impl AsUpdateStream` and also you want `Send` streams almost (?) always
// anyway.
type Stream: Stream<Item = Result<Update, Self::StreamErr>> + Send + 'a;
/// Creates the update [`Stream`].
///
/// [`Stream`]: AsUpdateStream::Stream
fn as_stream(&'a mut self) -> Self::Stream;
} }
#[inline(always)] #[inline(always)]

View file

@ -17,7 +17,7 @@ use crate::{
requests::{HasPayload, Request, Requester}, requests::{HasPayload, Request, Requester},
stop::{mk_stop_token, StopFlag, StopToken}, stop::{mk_stop_token, StopFlag, StopToken},
types::{AllowedUpdate, Update}, types::{AllowedUpdate, Update},
update_listeners::{assert_update_listener, AsUpdateStream, UpdateListener}, update_listeners::{assert_update_listener, UpdateListener},
}; };
/// Builder for polling update listener. /// Builder for polling update listener.
@ -243,7 +243,7 @@ where
/// [get_updates]: crate::requests::Requester::get_updates /// [get_updates]: crate::requests::Requester::get_updates
/// [`Dispatcher`]: crate::dispatching::Dispatcher /// [`Dispatcher`]: crate::dispatching::Dispatcher
#[must_use = "`Polling` is an update listener and does nothing unless used"] #[must_use = "`Polling` is an update listener and does nothing unless used"]
pub struct Polling<B: Requester> { pub struct Polling<B> {
bot: B, bot: B,
timeout: Option<Duration>, timeout: Option<Duration>,
limit: Option<u8>, limit: Option<u8>,
@ -320,7 +320,53 @@ pub struct PollingStream<'a, B: Requester> {
} }
impl<B: Requester + Send + 'static> UpdateListener for Polling<B> { impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
type Err = B::Err; type SetupErr = B::Err;
type StreamErr = B::Err;
type Stream<'a> = PollingStream<'a, B>;
fn listen(
&mut self,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>> {
Box::pin(async {
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
let allowed_updates = self.allowed_updates.clone();
let drop_pending_updates = self.drop_pending_updates;
let token_used_and_updated = self.reinit_stop_flag_if_needed();
// FIXME: document that `listen` is a destructive operation, actually,
// and you need to call `stop_token` *again* after it
//
// maybe also remove the panic, it's a lot of additional work, for little
// benefit, it seems like.
if token_used_and_updated {
panic!(
"detected calling `as_stream` a second time after calling `stop_token`. \
`as_stream` updates the stop token, thus you need to call it again after \
calling `as_stream`"
)
}
// FIXME: do update dropping *here*
// Unwrap: just called reinit
let flag = self.flag.take().unwrap();
let stream = PollingStream {
polling: self,
drop_pending_updates,
timeout,
allowed_updates,
offset: 0,
force_stop: false,
stopping: false,
buffer: Vec::new().into_iter(),
in_flight: None,
flag,
};
Ok(stream)
})
}
fn stop_token(&mut self) -> StopToken { fn stop_token(&mut self) -> StopToken {
self.reinit_stop_flag_if_needed(); self.reinit_stop_flag_if_needed();
@ -335,44 +381,6 @@ impl<B: Requester + Send + 'static> UpdateListener for Polling<B> {
} }
} }
impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling<B> {
type StreamErr = B::Err;
type Stream = PollingStream<'a, B>;
fn as_stream(&'a mut self) -> Self::Stream {
let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big"));
let allowed_updates = self.allowed_updates.clone();
let drop_pending_updates = self.drop_pending_updates;
let token_used_and_updated = self.reinit_stop_flag_if_needed();
// FIXME: document that `as_stream` is a destructive operation, actually,
// and you need to call `stop_token` *again* after it
if token_used_and_updated {
panic!(
"detected calling `as_stream` a second time after calling `stop_token`. \
`as_stream` updates the stop token, thus you need to call it again after calling \
`as_stream`"
)
}
// Unwrap: just called reinit
let flag = self.flag.take().unwrap();
PollingStream {
polling: self,
drop_pending_updates,
timeout,
allowed_updates,
offset: 0,
force_stop: false,
stopping: false,
buffer: Vec::new().into_iter(),
in_flight: None,
flag,
}
}
}
impl<B: Requester> Stream for PollingStream<'_, B> { impl<B: Requester> Stream for PollingStream<'_, B> {
type Item = Result<Update, B::Err>; type Item = Result<Update, B::Err>;
@ -471,8 +479,12 @@ fn polling_is_send() {
let mut polling = polling(bot, None, None, None); let mut polling = polling(bot, None, None, None);
assert_send(&polling); assert_send(&polling);
assert_send(&polling.as_stream()); assert_send(&polling.listen());
assert_send(&polling.stop_token()); assert_send(&polling.stop_token());
_ = async {
assert_send(&polling.listen().await.unwrap());
};
fn assert_send(_: &impl Send) {} fn assert_send(_: &impl Send) {}
} }

View file

@ -1,92 +0,0 @@
use futures::Stream;
use crate::{
stop::StopToken,
types::{AllowedUpdate, Update},
update_listeners::{AsUpdateStream, UpdateListener},
};
/// A listener created from functions.
///
/// This type allows to turn a stream of updates (+ some additional functions)
/// into an [`UpdateListener`].
///
/// For an example of usage, see [`polling`].
///
/// [`polling`]: crate::update_listeners::polling()
#[non_exhaustive]
pub struct StatefulListener<St, Assf, Sf, Hauf> {
/// The state of the listener.
pub state: St,
/// The function used as [`AsUpdateStream::as_stream`].
///
/// Must implement `for<'a> FnMut(&'a mut St) -> impl Stream + 'a`.
pub stream: Assf,
/// The function used as [`UpdateListener::stop_token`].
///
/// Must implement `FnMut(&mut St) -> StopToken`.
pub stop_token: Sf,
/// The function used as [`UpdateListener::hint_allowed_updates`].
///
/// Must implement `FnMut(&mut St, &mut dyn Iterator<Item =
/// AllowedUpdate>)`.
pub hint_allowed_updates: Option<Hauf>,
}
type Haufn<State> = for<'a, 'b> fn(&'a mut State, &'b mut dyn Iterator<Item = AllowedUpdate>);
impl<St, Assf, Sf> StatefulListener<St, Assf, Sf, Haufn<St>> {
/// Creates a new stateful listener from its components.
pub fn new(state: St, stream: Assf, stop_token: Sf) -> Self {
Self::new_with_hints(state, stream, stop_token, None)
}
}
impl<St, Assf, Sf, Hauf> StatefulListener<St, Assf, Sf, Hauf> {
/// Creates a new stateful listener from its components.
pub fn new_with_hints(
state: St,
stream: Assf,
stop_token: Sf,
hint_allowed_updates: Option<Hauf>,
) -> Self {
Self { state, stream, stop_token, hint_allowed_updates }
}
}
impl<'a, St, Assf, Sf, Hauf, Strm, E> AsUpdateStream<'a> for StatefulListener<St, Assf, Hauf, Sf>
where
(St, Strm): 'a,
Strm: Send,
Assf: FnMut(&'a mut St) -> Strm,
Strm: Stream<Item = Result<Update, E>>,
{
type StreamErr = E;
type Stream = Strm;
fn as_stream(&'a mut self) -> Self::Stream {
(self.stream)(&mut self.state)
}
}
impl<St, Assf, Sf, Hauf, E> UpdateListener for StatefulListener<St, Assf, Sf, Hauf>
where
Self: for<'a> AsUpdateStream<'a, StreamErr = E>,
Sf: FnMut(&mut St) -> StopToken,
Hauf: FnMut(&mut St, &mut dyn Iterator<Item = AllowedUpdate>),
{
type Err = E;
fn stop_token(&mut self) -> StopToken {
(self.stop_token)(&mut self.state)
}
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
if let Some(f) = &mut self.hint_allowed_updates {
f(&mut self.state, hint);
}
}
}

View file

@ -1,7 +1,7 @@
//! //!
use std::net::SocketAddr; use std::net::SocketAddr;
use crate::{requests::Requester, types::InputFile}; use crate::{requests::Requester, types::InputFile, update_listeners::AllowedUpdate};
/// Options related to setting up webhooks. /// Options related to setting up webhooks.
#[must_use] #[must_use]
@ -116,7 +116,7 @@ impl Options {
} }
#[cfg(feature = "webhooks-axum")] #[cfg(feature = "webhooks-axum")]
pub use self::axum::{axum, axum_no_setup, axum_to_router}; pub use self::axum::axum;
#[cfg(feature = "webhooks-axum")] #[cfg(feature = "webhooks-axum")]
mod axum; mod axum;
@ -124,25 +124,26 @@ mod axum;
// TODO: add different implementation (for example: warp) // TODO: add different implementation (for example: warp)
/// Calls `set_webhook` with arguments from `options`. /// Calls `set_webhook` with arguments from `options`.
/// async fn setup_webhook<R>(
/// Note: this takes out `certificate`. bot: R,
async fn setup_webhook<R>(bot: R, options: &mut Options) -> Result<(), R::Err> options: &Options,
allowed_updates: Option<Vec<AllowedUpdate>>,
) -> Result<(), R::Err>
where where
R: Requester, R: Requester,
{ {
use crate::requests::Request; use crate::requests::Request;
use teloxide_core::requests::HasPayload; use teloxide_core::requests::HasPayload;
let secret = options.get_or_gen_secret_token().to_owned(); let Options { url, certificate, secret_token, max_connections, drop_pending_updates, .. } =
let &mut Options { options;
ref url, ref mut certificate, max_connections, drop_pending_updates, ..
} = options;
let mut req = bot.set_webhook(url.clone()); let mut req = bot.set_webhook(url.clone());
req.payload_mut().certificate = certificate.take(); req.payload_mut().certificate = certificate.clone();
req.payload_mut().max_connections = max_connections; req.payload_mut().max_connections = *max_connections;
req.payload_mut().drop_pending_updates = Some(drop_pending_updates); req.payload_mut().drop_pending_updates = Some(*drop_pending_updates);
req.payload_mut().secret_token = Some(secret); req.payload_mut().secret_token = secret_token.clone();
req.payload_mut().allowed_updates = allowed_updates;
req.send().await?; req.send().await?;
@ -182,17 +183,3 @@ fn check_secret(bytes: &[u8]) -> Result<&[u8], &'static str> {
Ok(bytes) Ok(bytes)
} }
/// Returns first (`.0`) field from a tuple as a `&mut` reference.
///
/// This hack is needed because there isn't currently a way to easily force a
/// closure to be higher-ranked (`for<'a> &'a mut _ -> &'a mut _`) which causes
/// problems when using [`StatefulListener`] to implement update listener.
///
/// This could be probably removed once [rfc#3216] is implemented.
///
/// [`StatefulListener`]:
/// [rfc#3216]: https://github.com/rust-lang/rfcs/pull/3216
fn tuple_first_mut<A, B>(tuple: &mut (A, B)) -> &mut A {
&mut tuple.0
}

View file

@ -1,18 +1,66 @@
use std::{convert::Infallible, future::Future, pin::Pin}; use crate::{
requests::{Request, Requester},
stop::{mk_stop_token, StopFlag},
types::{AllowedUpdate, True, Update, UpdateKind},
update_listeners::{
webhooks::{setup_webhook, Options},
StopToken, UpdateListener,
},
};
use axum::{ use axum::{
extract::{FromRequestParts, State}, extract::{FromRequestParts, State},
http::{request::Parts, status::StatusCode}, http::{request::Parts, status::StatusCode},
response::IntoResponse,
routing::post,
}; };
use futures::stream::Stream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::trace::TraceLayer;
use crate::{ use std::{
requests::Requester, convert::Infallible,
stop::StopFlag, error::Error,
types::{Update, UpdateKind}, fmt::{self, Debug, Display},
update_listeners::{webhooks::Options, UpdateListener}, future::Future,
pin::Pin,
task,
task::Poll,
}; };
/// A webhook update listener backed by [`axum`](mod@axum).
pub struct Axum<B> {
bot: B,
options: Options,
token: StopToken,
flag: Option<StopFlag>,
allowed_updates: Option<Vec<AllowedUpdate>>,
/// This is a stream of updates, coming from an `axum::Router` we've
/// created.
///
/// N.B. This field is only initialized by `take_router` and is only
/// de-initialized by `listen`. Basically, it's a way to pass the
/// channel from the router creation, to the listener stream.
stream: Option<UnboundedReceiverStream<Update>>,
}
#[pin_project::pin_project]
pub struct AxumStream<'a, B: Requester> {
axum: &'a mut Axum<B>,
#[pin]
inner: UnboundedReceiverStream<Update>,
#[pin]
webhook_deletion: Option<Option<<B::DeleteWebhook as Request>::Send>>,
}
pub enum SetupError<B: Requester> {
Bind(hyper::Error),
SetWebhook(B::Err),
}
/// Webhook implementation based on the [mod@axum] framework. /// Webhook implementation based on the [mod@axum] framework.
/// ///
/// This function does all the work necessary for webhook to work, it: /// This function does all the work necessary for webhook to work, it:
@ -23,38 +71,71 @@ use crate::{
/// [`set_webhook`]: crate::payloads::SetWebhook /// [`set_webhook`]: crate::payloads::SetWebhook
/// [`delete_webhook`]: crate::payloads::DeleteWebhook /// [`delete_webhook`]: crate::payloads::DeleteWebhook
/// [`stop`]: crate::stop::StopToken::stop /// [`stop`]: crate::stop::StopToken::stop
/// pub fn axum<R>(bot: R, mut options: Options) -> Axum<R>
/// ## Panics
///
/// If binding to the [address] fails.
///
/// [address]: Options::address
///
/// ## Fails
///
/// If `set_webhook()` fails.
///
/// ## See also
///
/// [`axum_to_router`] and [`axum_no_setup`] for lower-level versions of this
/// function.
pub async fn axum<R>(
bot: R,
options: Options,
) -> Result<impl UpdateListener<Err = Infallible>, R::Err>
where where
R: Requester + Send + 'static, R: Requester + Sync + Send + Clone + 'static,
<R as Requester>::DeleteWebhook: Send, R::SetWebhook: Send,
R::DeleteWebhook: Send,
{ {
let Options { address, .. } = options; _ = options.get_or_gen_secret_token();
let (token, flag) = mk_stop_token();
let (mut update_listener, stop_flag, app) = axum_to_router(bot, options).await?; Axum { bot, options, token, flag: Some(flag), allowed_updates: None, stream: None }
let stop_token = update_listener.stop_token(); }
impl<B> Axum<B> {
/// Returns a router that will listen to updates.
///
/// N.B. you need to get a new router each time you re-start dispatching.
pub fn take_router(&mut self) -> Option<axum::Router> {
match self.stream {
None => {
self.reinit_stop_flag_if_needed();
let stop_flag = self.flag.as_ref().unwrap().clone();
let (router, stream) = create_router(&self.options, stop_flag);
self.stream = Some(stream);
Some(router)
}
Some(_) => None,
}
}
fn reinit_stop_flag_if_needed(&mut self) {
if self.flag.is_none() {
let (token, flag) = mk_stop_token();
self.token = token;
self.flag = Some(flag);
}
}
}
impl<B> UpdateListener for Axum<B>
where
B: Requester + Sync + Send + 'static,
B::SetWebhook: Send,
{
type SetupErr = SetupError<B>;
type StreamErr = Infallible;
type Stream<'a> = AxumStream<'a, B>;
fn listen(
&mut self,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'_>, Self::SetupErr>> + Send + '_>> {
Box::pin(async {
// Unwrap: `stop_flag` always returns `Some`.
self.reinit_stop_flag_if_needed();
let stop_flag = self.flag.take().unwrap();
let stop_token = self.token.clone();
// If the user did not take the router themselves — spawn an axum server
if let Some(router) = self.take_router() {
let server = axum::Server::try_bind(&self.options.address)
.map_err(SetupError::Bind)?
.serve(router.into_make_service())
.with_graceful_shutdown(stop_flag);
tokio::spawn(async move { tokio::spawn(async move {
axum::Server::bind(&address) server
.serve(app.into_make_service())
.with_graceful_shutdown(stop_flag)
.await .await
.map_err(|err| { .map_err(|err| {
stop_token.stop(); stop_token.stop();
@ -62,114 +143,139 @@ where
}) })
.expect("Axum server error"); .expect("Axum server error");
}); });
}
Ok(update_listener) // Unwrap: just called `take_router`
let stream = self.stream.take().unwrap();
setup_webhook(&self.bot, &self.options, self.allowed_updates.clone())
.await
.map_err(SetupError::SetWebhook)?;
let stream = AxumStream { axum: self, inner: stream, webhook_deletion: None };
Ok(stream)
})
}
fn stop_token(&mut self) -> StopToken {
self.reinit_stop_flag_if_needed();
self.token.clone()
}
fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator<Item = AllowedUpdate>) {
// TODO: we should probably warn if there already were different allowed updates
// before
self.allowed_updates = Some(hint.collect());
}
} }
/// Webhook implementation based on the [mod@axum] framework that can reuse impl<B> Stream for AxumStream<'_, B>
/// existing [mod@axum] server.
///
/// This function does most of the work necessary for webhook to work, it:
/// - Calls [`set_webhook`], so telegram starts sending updates our way
/// - When the update listener is [`stop`]ped, calls [`delete_webhook`]
///
/// The only missing part is running [mod@axum] server with a returned
/// [`axum::Router`].
///
/// This function is intended to be used in cases when you already have an
/// [mod@axum] server running and can reuse it for webhooks.
///
/// **Note**: in order for webhooks to work, you need to use returned
/// [`axum::Router`] in an [mod@axum] server that is bound to
/// [`options.address`].
///
/// It may also be desired to use [`with_graceful_shutdown`] with the returned
/// future in order to shutdown the server with the [`stop`] of the listener.
///
/// [`set_webhook`]: crate::payloads::SetWebhook
/// [`delete_webhook`]: crate::payloads::DeleteWebhook
/// [`stop`]: crate::stop::StopToken::stop
/// [`options.address`]: Options::address
/// [`with_graceful_shutdown`]: axum::Server::with_graceful_shutdown
///
/// ## Returns
///
/// A update listener, stop-future, axum router triplet on success.
///
/// The "stop-future" is resolved after [`stop`] is called on the stop token of
/// the returned update listener.
///
/// ## Fails
///
/// If `set_webhook()` fails.
///
/// ## See also
///
/// [`fn@axum`] for higher-level and [`axum_no_setup`] for lower-level
/// versions of this function.
pub async fn axum_to_router<R>(
bot: R,
mut options: Options,
) -> Result<
(impl UpdateListener<Err = Infallible>, impl Future<Output = ()> + Send, axum::Router),
R::Err,
>
where where
R: Requester + Send, B: Requester,
<R as Requester>::DeleteWebhook: Send,
{ {
use crate::{requests::Request, update_listeners::webhooks::setup_webhook}; type Item = Result<Update, Infallible>;
use futures::FutureExt;
setup_webhook(&bot, &mut options).await?; fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
let (listener, stop_flag, router) = axum_no_setup(options); match this.inner.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(upd)) => Poll::Ready(Some(Ok(upd))),
Poll::Ready(None) => {
if let Some(mut deletion) = this.webhook_deletion.as_mut().as_pin_mut() {
if let Some(future) = deletion.as_mut().as_pin_mut() {
// `Some(Some(_))` — we are currently deleting webhook
match future.poll(cx) {
Poll::Pending => Poll::Pending,
let stop_flag = stop_flag.then(move |()| async move { // We completed webhook deletion (potentially failed)
// This assignment is needed to not require `R: Sync` since without it `&bot` Poll::Ready(Ok(True) | Err(_)) => {
// temporary lives across `.await` points. this.webhook_deletion.set(Some(None));
let req = bot.delete_webhook().send(); Poll::Ready(None)
let res = req.await;
if let Err(err) = res {
log::error!("Couldn't delete webhook: {}", err);
} }
}
} else {
// `Some(None)` — we've already deleted webhook
Poll::Ready(None)
}
} else {
// `None` — we haven't yet started deleting webhook
this.webhook_deletion.set(Some(Some(this.axum.bot.delete_webhook().send())));
// Immediately wake up to poll `self.in_flight`
// (without this this stream becomes a zombie)
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
}
impl<R: Requester> Debug for SetupError<R>
where
R: Requester,
R::Err: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Bind(e) => f.debug_tuple("Bind").field(&e).finish(),
Self::SetWebhook(e) => f.debug_tuple("SetWebhook").field(&e).finish(),
}
}
}
impl<R> fmt::Display for SetupError<R>
where
R: Requester,
R::Err: Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Bind(e) => write!(f, "Error while binding an address for webhooks: {e}"),
Self::SetWebhook(e) => write!(f, "Error while setting up webhooks: {e}"),
}
}
}
impl<R> Error for SetupError<R>
where
R: Requester,
R::Err: Error + Debug + Display + 'static,
{
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
Self::Bind(e) => Some(e),
Self::SetWebhook(e) => Some(e),
}
}
}
fn create_router(
options: &Options,
stop_flag: StopFlag,
) -> (axum::Router, UnboundedReceiverStream<Update>) {
let (tx, rx): (mpsc::UnboundedSender<Update>, _) = mpsc::unbounded_channel();
let app = axum::Router::new()
.route(options.url.path(), post(telegram_request))
.layer(TraceLayer::new_for_http())
.with_state(WebhookState {
tx: ClosableSender::new(tx),
flag: stop_flag,
secret: options.secret_token.clone(),
}); });
Ok((listener, stop_flag, router)) (app, rx.into())
} }
/// Webhook implementation based on the [mod@axum] framework that doesn't async fn telegram_request(
/// perform any setup work.
///
/// ## Note about the stop-future
///
/// This function returns a future that is resolved when `.stop()` is called on
/// a stop token of the update listener. Note that even if the future is not
/// used, after `.stop()` is called, update listener will not produce new
/// updates.
///
/// ## See also
///
/// [`fn@axum`] and [`axum_to_router`] for higher-level versions of this
/// function.
pub fn axum_no_setup(
options: Options,
) -> (impl UpdateListener<Err = Infallible>, impl Future<Output = ()>, axum::Router) {
use crate::{
stop::{mk_stop_token, StopToken},
update_listeners::{webhooks::tuple_first_mut, StatefulListener},
};
use axum::{response::IntoResponse, routing::post};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tower_http::trace::TraceLayer;
let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel();
async fn telegram_request(
State(WebhookState { secret, flag, mut tx }): State<WebhookState>, State(WebhookState { secret, flag, mut tx }): State<WebhookState>,
secret_header: XTelegramBotApiSecretToken, secret_header: XTelegramBotApiSecretToken,
input: String, input: String,
) -> impl IntoResponse { ) -> impl IntoResponse {
// FIXME: use constant time comparison here // FIXME: use constant time comparison here
if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) { if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) {
return StatusCode::UNAUTHORIZED; return StatusCode::UNAUTHORIZED;
@ -194,7 +300,7 @@ pub fn axum_no_setup(
*value = serde_json::from_str(&input).unwrap_or_default(); *value = serde_json::from_str(&input).unwrap_or_default();
} }
tx.send(Ok(update)).expect("Cannot send an incoming update from the webhook") tx.send(update).expect("Cannot send an incoming update from the webhook")
} }
Err(error) => { Err(error) => {
log::error!( log::error!(
@ -208,37 +314,11 @@ pub fn axum_no_setup(
}; };
StatusCode::OK StatusCode::OK
}
let (stop_token, stop_flag) = mk_stop_token();
let app = axum::Router::new()
.route(options.url.path(), post(telegram_request))
.layer(TraceLayer::new_for_http())
.with_state(WebhookState {
tx: ClosableSender::new(tx),
flag: stop_flag.clone(),
secret: options.secret_token,
});
let stream = UnboundedReceiverStream::new(rx);
// FIXME: this should support `hint_allowed_updates()`
let listener = StatefulListener::new(
(stream, stop_token),
tuple_first_mut,
|state: &mut (_, StopToken)| state.1.clone(),
);
(listener, stop_flag, app)
} }
type UpdateSender = mpsc::UnboundedSender<Result<Update, std::convert::Infallible>>;
type UpdateCSender = ClosableSender<Result<Update, std::convert::Infallible>>;
#[derive(Clone)] #[derive(Clone)]
struct WebhookState { struct WebhookState {
tx: UpdateCSender, tx: ClosableSender<Update>,
flag: StopFlag, flag: StopFlag,
secret: Option<String>, secret: Option<String>,
} }