diff --git a/crates/teloxide/Cargo.toml b/crates/teloxide/Cargo.toml index 59c0fdb7..f3bf3ef6 100644 --- a/crates/teloxide/Cargo.toml +++ b/crates/teloxide/Cargo.toml @@ -21,7 +21,7 @@ categories = ["web-programming", "api-bindings", "asynchronous"] default = ["native-tls", "ctrlc_handler", "teloxide-core/default", "auto-send"] webhooks = ["rand"] -webhooks-axum = ["webhooks", "axum", "tower", "tower-http"] +webhooks-axum = ["webhooks", "axum", "tower", "tower-http", "hyper"] # FIXME: rename `sqlite-storage` -> `sqlite-storage-nativetls` sqlite-storage = ["sqlx", "sqlx/runtime-tokio-native-tls", "native-tls"] @@ -109,6 +109,7 @@ bincode = { version = "1.3", optional = true } axum = { version = "0.6.0", optional = true } tower = { version = "0.4.12", optional = true } tower-http = { version = "0.3.4", features = ["trace"], optional = true } +hyper = { version = "0.14", optional = true, default-features = false } rand = { version = "0.8.5", optional = true } diff --git a/crates/teloxide/examples/heroku_ping_pong.rs b/crates/teloxide/examples/heroku_ping_pong.rs index c9198132..c79ab75c 100644 --- a/crates/teloxide/examples/heroku_ping_pong.rs +++ b/crates/teloxide/examples/heroku_ping_pong.rs @@ -41,9 +41,7 @@ async fn main() { let host = env::var("HOST").expect("HOST env variable is not set"); let url = format!("https://{host}/webhook").parse().unwrap(); - let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)) - .await - .expect("Couldn't setup webhook"); + let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)); teloxide::repl_with_listener( bot, diff --git a/crates/teloxide/examples/ngrok_ping_pong.rs b/crates/teloxide/examples/ngrok_ping_pong.rs index b3f0b4b3..6eff3fbe 100644 --- a/crates/teloxide/examples/ngrok_ping_pong.rs +++ b/crates/teloxide/examples/ngrok_ping_pong.rs @@ -12,9 +12,7 @@ async fn main() { let addr = ([127, 0, 0, 1], 8443).into(); let url = "Your HTTPS ngrok URL here. Get it by `ngrok http 8443`".parse().unwrap(); - let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)) - .await - .expect("Couldn't setup webhook"); + let listener = webhooks::axum(bot.clone(), webhooks::Options::new(addr, url)); teloxide::repl_with_listener( bot, diff --git a/crates/teloxide/src/dispatching/dispatcher.rs b/crates/teloxide/src/dispatching/dispatcher.rs index 91db674a..f3b5a4d7 100644 --- a/crates/teloxide/src/dispatching/dispatcher.rs +++ b/crates/teloxide/src/dispatching/dispatcher.rs @@ -20,7 +20,8 @@ use tokio_stream::wrappers::ReceiverStream; use std::{ collections::HashMap, - fmt::Debug, + error::Error, + fmt::{self, Debug, Display}, future::Future, hash::Hash, ops::{ControlFlow, Deref}, @@ -214,6 +215,15 @@ pub struct Dispatcher { state: ShutdownToken, } +/// An error returned from [`Disatcher::try_dispatch_with_listener`]. +pub enum TryDispatchError { + /// An error from calling `get_me` while creating dispatcher context. + GetMe(R::Err), + + /// An error during update listener setup. + ListenerSetup(L::SetupErr), +} + struct Worker { tx: tokio::sync::mpsc::Sender, handle: tokio::task::JoinHandle<()>, @@ -300,8 +310,8 @@ where update_listener_error_handler: Arc, ) where UListener: UpdateListener + 'a, - Eh: ErrorHandler + 'a, - UListener::Err: Debug, + Eh: ErrorHandler + 'a, + UListener::SetupErr: Debug, { self.try_dispatch_with_listener(update_listener, update_listener_error_handler) .await @@ -319,14 +329,13 @@ where &'a mut self, mut update_listener: UListener, update_listener_error_handler: Arc, - ) -> Result<(), R::Err> + ) -> Result<(), TryDispatchError> where UListener: UpdateListener + 'a, - Eh: ErrorHandler + 'a, - UListener::Err: Debug, + Eh: ErrorHandler + 'a, { // FIXME: there should be a way to check if dependency is already inserted - let me = self.bot.get_me().send().await?; + let me = self.bot.get_me().send().await.map_err(TryDispatchError::GetMe)?; self.dependencies.insert(me); self.dependencies.insert(self.bot.clone()); @@ -340,7 +349,7 @@ where self.state.start_dispatching(); { - let stream = update_listener.as_stream(); + let stream = update_listener.listen().await.map_err(TryDispatchError::ListenerSetup)?; tokio::pin!(stream); loop { @@ -521,6 +530,51 @@ impl Dispatcher { } } +impl Debug for TryDispatchError +where + R: Requester, + R::Err: Debug, + L: UpdateListener, + L::SetupErr: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::GetMe(e) => f.debug_tuple("GetMe").field(&e).finish(), + Self::ListenerSetup(e) => f.debug_tuple("ListenerSetup").field(&e).finish(), + } + } +} + +impl fmt::Display for TryDispatchError +where + R: Requester, + R::Err: Display, + L: UpdateListener, + L::SetupErr: Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::GetMe(e) => write!(f, "Error while setting up update listener: {e}"), + Self::ListenerSetup(e) => write!(f, "Error while setting up update listener: {e}"), + } + } +} + +impl Error for TryDispatchError +where + R: Requester, + R::Err: Error + Debug + Display + 'static, + L: UpdateListener, + L::SetupErr: Error + Debug + Display + 'static, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::GetMe(e) => Some(e), + Self::ListenerSetup(e) => Some(e), + } + } +} + fn spawn_worker( deps: DependencyMap, handler: Arc>, diff --git a/crates/teloxide/src/repls/commands_repl.rs b/crates/teloxide/src/repls/commands_repl.rs index c846089e..1f9cbebb 100644 --- a/crates/teloxide/src/repls/commands_repl.rs +++ b/crates/teloxide/src/repls/commands_repl.rs @@ -85,8 +85,11 @@ pub trait CommandReplExt { fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()> where H: Injectable, Args> + Send + Sync + 'static, - L: UpdateListener + Send + 'a, - L::Err: Debug + Send + 'a, + // FIXME: why do we need + 'static here??? (and in similar cases) (compiler error does not + // provide enough info...) + L: UpdateListener + Send + 'static, + L::SetupErr: Debug, + L::StreamErr: Debug + Send + 'a, R: Requester + Clone + Send + Sync + 'static, ::GetMe: Send; } @@ -120,8 +123,9 @@ where fn repl_with_listener<'a, R, H, L, Args>(bot: R, handler: H, listener: L) -> BoxFuture<'a, ()> where H: Injectable, Args> + Send + Sync + 'static, - L: UpdateListener + Send + 'a, - L::Err: Debug + Send + 'a, + L: UpdateListener + Send + 'static, + L::SetupErr: Debug, + L::StreamErr: Debug + Send + 'a, R: Requester + Clone + Send + Sync + 'static, ::GetMe: Send, { @@ -277,7 +281,8 @@ pub async fn commands_repl_with_listener<'a, R, Cmd, H, L, Args>( Cmd: BotCommands + Send + Sync + 'static, H: Injectable, Args> + Send + Sync + 'static, L: UpdateListener + Send + 'a, - L::Err: Debug + Send + 'a, + L::SetupErr: Debug, + L::StreamErr: Debug, R: Requester + Clone + Send + Sync + 'static, { use crate::dispatching::Dispatcher; diff --git a/crates/teloxide/src/repls/repl.rs b/crates/teloxide/src/repls/repl.rs index dcb0a308..728593ae 100644 --- a/crates/teloxide/src/repls/repl.rs +++ b/crates/teloxide/src/repls/repl.rs @@ -108,7 +108,8 @@ where R: Requester + Clone + Send + Sync + 'static, H: Injectable, Args> + Send + Sync + 'static, L: UpdateListener + Send, - L::Err: Debug, + L::SetupErr: Debug, + L::StreamErr: Debug, { use crate::dispatching::Dispatcher; diff --git a/crates/teloxide/src/update_listeners.rs b/crates/teloxide/src/update_listeners.rs index 6e03ef01..66ca97d4 100644 --- a/crates/teloxide/src/update_listeners.rs +++ b/crates/teloxide/src/update_listeners.rs @@ -30,7 +30,9 @@ #[cfg(feature = "webhooks")] pub mod webhooks; -use futures::Stream; +use std::pin::Pin; + +use futures::{Future, Stream}; use crate::{ stop::StopToken, @@ -38,30 +40,54 @@ use crate::{ }; mod polling; -mod stateful_listener; #[allow(deprecated)] -pub use self::{ - polling::{polling, polling_default, Polling, PollingBuilder, PollingStream}, - stateful_listener::StatefulListener, -}; +pub use self::polling::{polling, polling_default, Polling, PollingBuilder, PollingStream}; /// An update listener. /// /// Implementors of this trait allow getting updates from Telegram. See /// [module-level documentation] for more. /// -/// Some functions of this trait are located in the supertrait -/// ([`AsUpdateStream`]), see also: -/// - [`AsUpdateStream::Stream`] -/// - [`AsUpdateStream::as_stream`] -/// /// [module-level documentation]: mod@self -pub trait UpdateListener: - for<'a> AsUpdateStream<'a, StreamErr = ::Err> -{ - /// The type of errors that can be returned from this listener. - type Err; +pub trait UpdateListener { + type SetupErr; + + /// Error that can be returned from the [`Stream`] + /// + /// [`Stream`]: UpdateListener::Stream + type StreamErr; + + /// The stream of updates from Telegram. + // NB: `Send` is not strictly required here, but it makes it easier to return + // `impl AsUpdateStream` and also you want `Send` streams almost (?) always + // anyway. + type Stream<'a>: Stream> + Send + 'a + where + Self: 'a; + + /// Creates the update [`Stream`]. + /// + /// This function should also do all the necessary setup, and return an + /// error if something goes wrong with it. For example for webhooks this + /// should call `set_webhook`. + /// + /// [`Stream`]: AsUpdateStream::Stream + fn listen( + &mut self, + ) -> Pin, Self::SetupErr>> + Send + '_>>; + + /// Hint which updates should the listener listen for. + /// + /// For example [`polling()`] should send the hint as + /// [`GetUpdates::allowed_updates`] + /// + /// Note: this is a very important method, without setting appropriate + /// allowed updates, telegram will not send some update kinds. + /// + /// [`GetUpdates::allowed_updates`]: + /// crate::payloads::GetUpdates::allowed_updates + fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator); /// Returns a token which stops this listener. /// @@ -77,44 +103,6 @@ pub trait UpdateListener: #[must_use = "This function doesn't stop listening, to stop listening you need to call `stop` \ on the returned token"] fn stop_token(&mut self) -> StopToken; - - /// Hint which updates should the listener listen for. - /// - /// For example [`polling()`] should send the hint as - /// [`GetUpdates::allowed_updates`] - /// - /// Note however that this is a _hint_ and as such, it can be ignored. The - /// listener is not guaranteed to only return updates which types are listed - /// in the hint. - /// - /// [`GetUpdates::allowed_updates`]: - /// crate::payloads::GetUpdates::allowed_updates - fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator) { - let _ = hint; - } -} - -/// [`UpdateListener`]'s supertrait/extension. -/// -/// This trait is a workaround to not require GAT. -pub trait AsUpdateStream<'a> { - /// Error that can be returned from the [`Stream`] - /// - /// [`Stream`]: AsUpdateStream::Stream - // NB: This should be named differently to `UpdateListener::Err`, so that it's - // unambiguous - type StreamErr; - - /// The stream of updates from Telegram. - // NB: `Send` is not strictly required here, but it makes it easier to return - // `impl AsUpdateStream` and also you want `Send` streams almost (?) always - // anyway. - type Stream: Stream> + Send + 'a; - - /// Creates the update [`Stream`]. - /// - /// [`Stream`]: AsUpdateStream::Stream - fn as_stream(&'a mut self) -> Self::Stream; } #[inline(always)] diff --git a/crates/teloxide/src/update_listeners/polling.rs b/crates/teloxide/src/update_listeners/polling.rs index 62d63076..7017f9b9 100644 --- a/crates/teloxide/src/update_listeners/polling.rs +++ b/crates/teloxide/src/update_listeners/polling.rs @@ -17,7 +17,7 @@ use crate::{ requests::{HasPayload, Request, Requester}, stop::{mk_stop_token, StopFlag, StopToken}, types::{AllowedUpdate, Update}, - update_listeners::{assert_update_listener, AsUpdateStream, UpdateListener}, + update_listeners::{assert_update_listener, UpdateListener}, }; /// Builder for polling update listener. @@ -243,7 +243,7 @@ where /// [get_updates]: crate::requests::Requester::get_updates /// [`Dispatcher`]: crate::dispatching::Dispatcher #[must_use = "`Polling` is an update listener and does nothing unless used"] -pub struct Polling { +pub struct Polling { bot: B, timeout: Option, limit: Option, @@ -320,7 +320,53 @@ pub struct PollingStream<'a, B: Requester> { } impl UpdateListener for Polling { - type Err = B::Err; + type SetupErr = B::Err; + type StreamErr = B::Err; + type Stream<'a> = PollingStream<'a, B>; + + fn listen( + &mut self, + ) -> Pin, Self::SetupErr>> + Send + '_>> { + Box::pin(async { + let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")); + let allowed_updates = self.allowed_updates.clone(); + let drop_pending_updates = self.drop_pending_updates; + + let token_used_and_updated = self.reinit_stop_flag_if_needed(); + + // FIXME: document that `listen` is a destructive operation, actually, + // and you need to call `stop_token` *again* after it + // + // maybe also remove the panic, it's a lot of additional work, for little + // benefit, it seems like. + if token_used_and_updated { + panic!( + "detected calling `as_stream` a second time after calling `stop_token`. \ + `as_stream` updates the stop token, thus you need to call it again after \ + calling `as_stream`" + ) + } + + // FIXME: do update dropping *here* + + // Unwrap: just called reinit + let flag = self.flag.take().unwrap(); + let stream = PollingStream { + polling: self, + drop_pending_updates, + timeout, + allowed_updates, + offset: 0, + force_stop: false, + stopping: false, + buffer: Vec::new().into_iter(), + in_flight: None, + flag, + }; + + Ok(stream) + }) + } fn stop_token(&mut self) -> StopToken { self.reinit_stop_flag_if_needed(); @@ -335,44 +381,6 @@ impl UpdateListener for Polling { } } -impl<'a, B: Requester + Send + 'a> AsUpdateStream<'a> for Polling { - type StreamErr = B::Err; - type Stream = PollingStream<'a, B>; - - fn as_stream(&'a mut self) -> Self::Stream { - let timeout = self.timeout.map(|t| t.as_secs().try_into().expect("timeout is too big")); - let allowed_updates = self.allowed_updates.clone(); - let drop_pending_updates = self.drop_pending_updates; - - let token_used_and_updated = self.reinit_stop_flag_if_needed(); - - // FIXME: document that `as_stream` is a destructive operation, actually, - // and you need to call `stop_token` *again* after it - if token_used_and_updated { - panic!( - "detected calling `as_stream` a second time after calling `stop_token`. \ - `as_stream` updates the stop token, thus you need to call it again after calling \ - `as_stream`" - ) - } - - // Unwrap: just called reinit - let flag = self.flag.take().unwrap(); - PollingStream { - polling: self, - drop_pending_updates, - timeout, - allowed_updates, - offset: 0, - force_stop: false, - stopping: false, - buffer: Vec::new().into_iter(), - in_flight: None, - flag, - } - } -} - impl Stream for PollingStream<'_, B> { type Item = Result; @@ -471,8 +479,12 @@ fn polling_is_send() { let mut polling = polling(bot, None, None, None); assert_send(&polling); - assert_send(&polling.as_stream()); + assert_send(&polling.listen()); assert_send(&polling.stop_token()); + _ = async { + assert_send(&polling.listen().await.unwrap()); + }; + fn assert_send(_: &impl Send) {} } diff --git a/crates/teloxide/src/update_listeners/stateful_listener.rs b/crates/teloxide/src/update_listeners/stateful_listener.rs deleted file mode 100644 index 87ae492a..00000000 --- a/crates/teloxide/src/update_listeners/stateful_listener.rs +++ /dev/null @@ -1,92 +0,0 @@ -use futures::Stream; - -use crate::{ - stop::StopToken, - types::{AllowedUpdate, Update}, - update_listeners::{AsUpdateStream, UpdateListener}, -}; - -/// A listener created from functions. -/// -/// This type allows to turn a stream of updates (+ some additional functions) -/// into an [`UpdateListener`]. -/// -/// For an example of usage, see [`polling`]. -/// -/// [`polling`]: crate::update_listeners::polling() -#[non_exhaustive] -pub struct StatefulListener { - /// The state of the listener. - pub state: St, - - /// The function used as [`AsUpdateStream::as_stream`]. - /// - /// Must implement `for<'a> FnMut(&'a mut St) -> impl Stream + 'a`. - pub stream: Assf, - - /// The function used as [`UpdateListener::stop_token`]. - /// - /// Must implement `FnMut(&mut St) -> StopToken`. - pub stop_token: Sf, - - /// The function used as [`UpdateListener::hint_allowed_updates`]. - /// - /// Must implement `FnMut(&mut St, &mut dyn Iterator)`. - pub hint_allowed_updates: Option, -} - -type Haufn = for<'a, 'b> fn(&'a mut State, &'b mut dyn Iterator); - -impl StatefulListener> { - /// Creates a new stateful listener from its components. - pub fn new(state: St, stream: Assf, stop_token: Sf) -> Self { - Self::new_with_hints(state, stream, stop_token, None) - } -} - -impl StatefulListener { - /// Creates a new stateful listener from its components. - pub fn new_with_hints( - state: St, - stream: Assf, - stop_token: Sf, - hint_allowed_updates: Option, - ) -> Self { - Self { state, stream, stop_token, hint_allowed_updates } - } -} - -impl<'a, St, Assf, Sf, Hauf, Strm, E> AsUpdateStream<'a> for StatefulListener -where - (St, Strm): 'a, - Strm: Send, - Assf: FnMut(&'a mut St) -> Strm, - Strm: Stream>, -{ - type StreamErr = E; - type Stream = Strm; - - fn as_stream(&'a mut self) -> Self::Stream { - (self.stream)(&mut self.state) - } -} - -impl UpdateListener for StatefulListener -where - Self: for<'a> AsUpdateStream<'a, StreamErr = E>, - Sf: FnMut(&mut St) -> StopToken, - Hauf: FnMut(&mut St, &mut dyn Iterator), -{ - type Err = E; - - fn stop_token(&mut self) -> StopToken { - (self.stop_token)(&mut self.state) - } - - fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator) { - if let Some(f) = &mut self.hint_allowed_updates { - f(&mut self.state, hint); - } - } -} diff --git a/crates/teloxide/src/update_listeners/webhooks.rs b/crates/teloxide/src/update_listeners/webhooks.rs index e327016b..6c96940a 100644 --- a/crates/teloxide/src/update_listeners/webhooks.rs +++ b/crates/teloxide/src/update_listeners/webhooks.rs @@ -1,7 +1,7 @@ //! use std::net::SocketAddr; -use crate::{requests::Requester, types::InputFile}; +use crate::{requests::Requester, types::InputFile, update_listeners::AllowedUpdate}; /// Options related to setting up webhooks. #[must_use] @@ -116,7 +116,7 @@ impl Options { } #[cfg(feature = "webhooks-axum")] -pub use self::axum::{axum, axum_no_setup, axum_to_router}; +pub use self::axum::axum; #[cfg(feature = "webhooks-axum")] mod axum; @@ -124,25 +124,26 @@ mod axum; // TODO: add different implementation (for example: warp) /// Calls `set_webhook` with arguments from `options`. -/// -/// Note: this takes out `certificate`. -async fn setup_webhook(bot: R, options: &mut Options) -> Result<(), R::Err> +async fn setup_webhook( + bot: R, + options: &Options, + allowed_updates: Option>, +) -> Result<(), R::Err> where R: Requester, { use crate::requests::Request; use teloxide_core::requests::HasPayload; - let secret = options.get_or_gen_secret_token().to_owned(); - let &mut Options { - ref url, ref mut certificate, max_connections, drop_pending_updates, .. - } = options; + let Options { url, certificate, secret_token, max_connections, drop_pending_updates, .. } = + options; let mut req = bot.set_webhook(url.clone()); - req.payload_mut().certificate = certificate.take(); - req.payload_mut().max_connections = max_connections; - req.payload_mut().drop_pending_updates = Some(drop_pending_updates); - req.payload_mut().secret_token = Some(secret); + req.payload_mut().certificate = certificate.clone(); + req.payload_mut().max_connections = *max_connections; + req.payload_mut().drop_pending_updates = Some(*drop_pending_updates); + req.payload_mut().secret_token = secret_token.clone(); + req.payload_mut().allowed_updates = allowed_updates; req.send().await?; @@ -182,17 +183,3 @@ fn check_secret(bytes: &[u8]) -> Result<&[u8], &'static str> { Ok(bytes) } - -/// Returns first (`.0`) field from a tuple as a `&mut` reference. -/// -/// This hack is needed because there isn't currently a way to easily force a -/// closure to be higher-ranked (`for<'a> &'a mut _ -> &'a mut _`) which causes -/// problems when using [`StatefulListener`] to implement update listener. -/// -/// This could be probably removed once [rfc#3216] is implemented. -/// -/// [`StatefulListener`]: -/// [rfc#3216]: https://github.com/rust-lang/rfcs/pull/3216 -fn tuple_first_mut(tuple: &mut (A, B)) -> &mut A { - &mut tuple.0 -} diff --git a/crates/teloxide/src/update_listeners/webhooks/axum.rs b/crates/teloxide/src/update_listeners/webhooks/axum.rs index 2a915a55..49123532 100644 --- a/crates/teloxide/src/update_listeners/webhooks/axum.rs +++ b/crates/teloxide/src/update_listeners/webhooks/axum.rs @@ -1,18 +1,66 @@ -use std::{convert::Infallible, future::Future, pin::Pin}; +use crate::{ + requests::{Request, Requester}, + stop::{mk_stop_token, StopFlag}, + types::{AllowedUpdate, True, Update, UpdateKind}, + update_listeners::{ + webhooks::{setup_webhook, Options}, + StopToken, UpdateListener, + }, +}; use axum::{ extract::{FromRequestParts, State}, http::{request::Parts, status::StatusCode}, + response::IntoResponse, + routing::post, }; +use futures::stream::Stream; use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tower_http::trace::TraceLayer; -use crate::{ - requests::Requester, - stop::StopFlag, - types::{Update, UpdateKind}, - update_listeners::{webhooks::Options, UpdateListener}, +use std::{ + convert::Infallible, + error::Error, + fmt::{self, Debug, Display}, + future::Future, + pin::Pin, + task, + task::Poll, }; +/// A webhook update listener backed by [`axum`](mod@axum). +pub struct Axum { + bot: B, + options: Options, + token: StopToken, + flag: Option, + allowed_updates: Option>, + /// This is a stream of updates, coming from an `axum::Router` we've + /// created. + /// + /// N.B. This field is only initialized by `take_router` and is only + /// de-initialized by `listen`. Basically, it's a way to pass the + /// channel from the router creation, to the listener stream. + stream: Option>, +} + +#[pin_project::pin_project] +pub struct AxumStream<'a, B: Requester> { + axum: &'a mut Axum, + + #[pin] + inner: UnboundedReceiverStream, + + #[pin] + webhook_deletion: Option::Send>>, +} + +pub enum SetupError { + Bind(hyper::Error), + SetWebhook(B::Err), +} + /// Webhook implementation based on the [mod@axum] framework. /// /// This function does all the work necessary for webhook to work, it: @@ -23,222 +71,254 @@ use crate::{ /// [`set_webhook`]: crate::payloads::SetWebhook /// [`delete_webhook`]: crate::payloads::DeleteWebhook /// [`stop`]: crate::stop::StopToken::stop -/// -/// ## Panics -/// -/// If binding to the [address] fails. -/// -/// [address]: Options::address -/// -/// ## Fails -/// -/// If `set_webhook()` fails. -/// -/// ## See also -/// -/// [`axum_to_router`] and [`axum_no_setup`] for lower-level versions of this -/// function. -pub async fn axum( - bot: R, - options: Options, -) -> Result, R::Err> +pub fn axum(bot: R, mut options: Options) -> Axum where - R: Requester + Send + 'static, - ::DeleteWebhook: Send, + R: Requester + Sync + Send + Clone + 'static, + R::SetWebhook: Send, + R::DeleteWebhook: Send, { - let Options { address, .. } = options; + _ = options.get_or_gen_secret_token(); + let (token, flag) = mk_stop_token(); - let (mut update_listener, stop_flag, app) = axum_to_router(bot, options).await?; - let stop_token = update_listener.stop_token(); - - tokio::spawn(async move { - axum::Server::bind(&address) - .serve(app.into_make_service()) - .with_graceful_shutdown(stop_flag) - .await - .map_err(|err| { - stop_token.stop(); - err - }) - .expect("Axum server error"); - }); - - Ok(update_listener) + Axum { bot, options, token, flag: Some(flag), allowed_updates: None, stream: None } } -/// Webhook implementation based on the [mod@axum] framework that can reuse -/// existing [mod@axum] server. -/// -/// This function does most of the work necessary for webhook to work, it: -/// - Calls [`set_webhook`], so telegram starts sending updates our way -/// - When the update listener is [`stop`]ped, calls [`delete_webhook`] -/// -/// The only missing part is running [mod@axum] server with a returned -/// [`axum::Router`]. -/// -/// This function is intended to be used in cases when you already have an -/// [mod@axum] server running and can reuse it for webhooks. -/// -/// **Note**: in order for webhooks to work, you need to use returned -/// [`axum::Router`] in an [mod@axum] server that is bound to -/// [`options.address`]. -/// -/// It may also be desired to use [`with_graceful_shutdown`] with the returned -/// future in order to shutdown the server with the [`stop`] of the listener. -/// -/// [`set_webhook`]: crate::payloads::SetWebhook -/// [`delete_webhook`]: crate::payloads::DeleteWebhook -/// [`stop`]: crate::stop::StopToken::stop -/// [`options.address`]: Options::address -/// [`with_graceful_shutdown`]: axum::Server::with_graceful_shutdown -/// -/// ## Returns -/// -/// A update listener, stop-future, axum router triplet on success. -/// -/// The "stop-future" is resolved after [`stop`] is called on the stop token of -/// the returned update listener. -/// -/// ## Fails -/// -/// If `set_webhook()` fails. -/// -/// ## See also -/// -/// [`fn@axum`] for higher-level and [`axum_no_setup`] for lower-level -/// versions of this function. -pub async fn axum_to_router( - bot: R, - mut options: Options, -) -> Result< - (impl UpdateListener, impl Future + Send, axum::Router), - R::Err, -> -where - R: Requester + Send, - ::DeleteWebhook: Send, -{ - use crate::{requests::Request, update_listeners::webhooks::setup_webhook}; - use futures::FutureExt; - - setup_webhook(&bot, &mut options).await?; - - let (listener, stop_flag, router) = axum_no_setup(options); - - let stop_flag = stop_flag.then(move |()| async move { - // This assignment is needed to not require `R: Sync` since without it `&bot` - // temporary lives across `.await` points. - let req = bot.delete_webhook().send(); - let res = req.await; - if let Err(err) = res { - log::error!("Couldn't delete webhook: {}", err); +impl Axum { + /// Returns a router that will listen to updates. + /// + /// N.B. you need to get a new router each time you re-start dispatching. + pub fn take_router(&mut self) -> Option { + match self.stream { + None => { + self.reinit_stop_flag_if_needed(); + let stop_flag = self.flag.as_ref().unwrap().clone(); + let (router, stream) = create_router(&self.options, stop_flag); + self.stream = Some(stream); + Some(router) + } + Some(_) => None, } - }); - - Ok((listener, stop_flag, router)) -} - -/// Webhook implementation based on the [mod@axum] framework that doesn't -/// perform any setup work. -/// -/// ## Note about the stop-future -/// -/// This function returns a future that is resolved when `.stop()` is called on -/// a stop token of the update listener. Note that even if the future is not -/// used, after `.stop()` is called, update listener will not produce new -/// updates. -/// -/// ## See also -/// -/// [`fn@axum`] and [`axum_to_router`] for higher-level versions of this -/// function. -pub fn axum_no_setup( - options: Options, -) -> (impl UpdateListener, impl Future, axum::Router) { - use crate::{ - stop::{mk_stop_token, StopToken}, - update_listeners::{webhooks::tuple_first_mut, StatefulListener}, - }; - use axum::{response::IntoResponse, routing::post}; - use tokio_stream::wrappers::UnboundedReceiverStream; - use tower_http::trace::TraceLayer; - - let (tx, rx): (UpdateSender, _) = mpsc::unbounded_channel(); - - async fn telegram_request( - State(WebhookState { secret, flag, mut tx }): State, - secret_header: XTelegramBotApiSecretToken, - input: String, - ) -> impl IntoResponse { - // FIXME: use constant time comparison here - if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) { - return StatusCode::UNAUTHORIZED; - } - - let tx = match tx.get() { - None => return StatusCode::SERVICE_UNAVAILABLE, - // Do not process updates after `.stop()` is called even if the server is still - // running (useful for when you need to stop the bot but can't stop the server). - _ if flag.is_stopped() => { - tx.close(); - return StatusCode::SERVICE_UNAVAILABLE; - } - Some(tx) => tx, - }; - - match serde_json::from_str::(&input) { - Ok(mut update) => { - // See HACK comment in - // `teloxide_core::net::request::process_response::{closure#0}` - if let UpdateKind::Error(value) = &mut update.kind { - *value = serde_json::from_str(&input).unwrap_or_default(); - } - - tx.send(Ok(update)).expect("Cannot send an incoming update from the webhook") - } - Err(error) => { - log::error!( - "Cannot parse an update.\nError: {:?}\nValue: {}\n\ - This is a bug in teloxide-core, please open an issue here: \ - https://github.com/teloxide/teloxide/issues.", - error, - input - ); - } - }; - - StatusCode::OK } - let (stop_token, stop_flag) = mk_stop_token(); + fn reinit_stop_flag_if_needed(&mut self) { + if self.flag.is_none() { + let (token, flag) = mk_stop_token(); + self.token = token; + self.flag = Some(flag); + } + } +} + +impl UpdateListener for Axum +where + B: Requester + Sync + Send + 'static, + B::SetWebhook: Send, +{ + type SetupErr = SetupError; + type StreamErr = Infallible; + type Stream<'a> = AxumStream<'a, B>; + + fn listen( + &mut self, + ) -> Pin, Self::SetupErr>> + Send + '_>> { + Box::pin(async { + // Unwrap: `stop_flag` always returns `Some`. + self.reinit_stop_flag_if_needed(); + let stop_flag = self.flag.take().unwrap(); + let stop_token = self.token.clone(); + + // If the user did not take the router themselves — spawn an axum server + if let Some(router) = self.take_router() { + let server = axum::Server::try_bind(&self.options.address) + .map_err(SetupError::Bind)? + .serve(router.into_make_service()) + .with_graceful_shutdown(stop_flag); + + tokio::spawn(async move { + server + .await + .map_err(|err| { + stop_token.stop(); + err + }) + .expect("Axum server error"); + }); + } + + // Unwrap: just called `take_router` + let stream = self.stream.take().unwrap(); + + setup_webhook(&self.bot, &self.options, self.allowed_updates.clone()) + .await + .map_err(SetupError::SetWebhook)?; + + let stream = AxumStream { axum: self, inner: stream, webhook_deletion: None }; + + Ok(stream) + }) + } + + fn stop_token(&mut self) -> StopToken { + self.reinit_stop_flag_if_needed(); + self.token.clone() + } + + fn hint_allowed_updates(&mut self, hint: &mut dyn Iterator) { + // TODO: we should probably warn if there already were different allowed updates + // before + self.allowed_updates = Some(hint.collect()); + } +} + +impl Stream for AxumStream<'_, B> +where + B: Requester, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let mut this = self.project(); + + match this.inner.poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(upd)) => Poll::Ready(Some(Ok(upd))), + Poll::Ready(None) => { + if let Some(mut deletion) = this.webhook_deletion.as_mut().as_pin_mut() { + if let Some(future) = deletion.as_mut().as_pin_mut() { + // `Some(Some(_))` — we are currently deleting webhook + match future.poll(cx) { + Poll::Pending => Poll::Pending, + + // We completed webhook deletion (potentially failed) + Poll::Ready(Ok(True) | Err(_)) => { + this.webhook_deletion.set(Some(None)); + Poll::Ready(None) + } + } + } else { + // `Some(None)` — we've already deleted webhook + Poll::Ready(None) + } + } else { + // `None` — we haven't yet started deleting webhook + + this.webhook_deletion.set(Some(Some(this.axum.bot.delete_webhook().send()))); + + // Immediately wake up to poll `self.in_flight` + // (without this this stream becomes a zombie) + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + } +} + +impl Debug for SetupError +where + R: Requester, + R::Err: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Bind(e) => f.debug_tuple("Bind").field(&e).finish(), + Self::SetWebhook(e) => f.debug_tuple("SetWebhook").field(&e).finish(), + } + } +} + +impl fmt::Display for SetupError +where + R: Requester, + R::Err: Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Bind(e) => write!(f, "Error while binding an address for webhooks: {e}"), + Self::SetWebhook(e) => write!(f, "Error while setting up webhooks: {e}"), + } + } +} + +impl Error for SetupError +where + R: Requester, + R::Err: Error + Debug + Display + 'static, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::Bind(e) => Some(e), + Self::SetWebhook(e) => Some(e), + } + } +} + +fn create_router( + options: &Options, + stop_flag: StopFlag, +) -> (axum::Router, UnboundedReceiverStream) { + let (tx, rx): (mpsc::UnboundedSender, _) = mpsc::unbounded_channel(); let app = axum::Router::new() .route(options.url.path(), post(telegram_request)) .layer(TraceLayer::new_for_http()) .with_state(WebhookState { tx: ClosableSender::new(tx), - flag: stop_flag.clone(), - secret: options.secret_token, + flag: stop_flag, + secret: options.secret_token.clone(), }); - let stream = UnboundedReceiverStream::new(rx); - - // FIXME: this should support `hint_allowed_updates()` - let listener = StatefulListener::new( - (stream, stop_token), - tuple_first_mut, - |state: &mut (_, StopToken)| state.1.clone(), - ); - - (listener, stop_flag, app) + (app, rx.into()) } -type UpdateSender = mpsc::UnboundedSender>; -type UpdateCSender = ClosableSender>; +async fn telegram_request( + State(WebhookState { secret, flag, mut tx }): State, + secret_header: XTelegramBotApiSecretToken, + input: String, +) -> impl IntoResponse { + // FIXME: use constant time comparison here + if secret_header.0.as_deref() != secret.as_deref().map(str::as_bytes) { + return StatusCode::UNAUTHORIZED; + } + + let tx = match tx.get() { + None => return StatusCode::SERVICE_UNAVAILABLE, + // Do not process updates after `.stop()` is called even if the server is still + // running (useful for when you need to stop the bot but can't stop the server). + _ if flag.is_stopped() => { + tx.close(); + return StatusCode::SERVICE_UNAVAILABLE; + } + Some(tx) => tx, + }; + + match serde_json::from_str::(&input) { + Ok(mut update) => { + // See HACK comment in + // `teloxide_core::net::request::process_response::{closure#0}` + if let UpdateKind::Error(value) = &mut update.kind { + *value = serde_json::from_str(&input).unwrap_or_default(); + } + + tx.send(update).expect("Cannot send an incoming update from the webhook") + } + Err(error) => { + log::error!( + "Cannot parse an update.\nError: {:?}\nValue: {}\n\ + This is a bug in teloxide-core, please open an issue here: \ + https://github.com/teloxide/teloxide/issues.", + error, + input + ); + } + }; + + StatusCode::OK +} #[derive(Clone)] struct WebhookState { - tx: UpdateCSender, + tx: ClosableSender, flag: StopFlag, secret: Option, }