diff --git a/examples/dialogue_bot/src/main.rs b/examples/dialogue_bot/src/main.rs index bbac3112..663c503c 100644 --- a/examples/dialogue_bot/src/main.rs +++ b/examples/dialogue_bot/src/main.rs @@ -19,6 +19,7 @@ #[macro_use] extern crate smart_default; +use std::convert::Infallible; use teloxide::{ prelude::*, types::{KeyboardButton, ReplyKeyboardMarkup}, @@ -87,7 +88,7 @@ enum Dialogue { // [Control a dialogue] // ============================================================================ -type Cx<State> = DialogueDispatcherHandlerCx<Message, State>; +type Cx<State> = DialogueDispatcherHandlerCx<Message, State, Infallible>; type Res = ResponseResult<DialogueStage<Dialogue>>; async fn start(cx: Cx<()>) -> Res { @@ -118,13 +119,13 @@ async fn age(cx: Cx<ReceiveAgeState>) -> Res { .send() .await?; next(Dialogue::ReceiveFavouriteMusic(ReceiveFavouriteMusicState { - data: cx.dialogue, + data: cx.dialogue.unwrap(), age, })) } Err(_) => { cx.answer("Oh, please, enter a number!").send().await?; - next(Dialogue::ReceiveAge(cx.dialogue)) + next(Dialogue::ReceiveAge(cx.dialogue.unwrap())) } } } @@ -134,7 +135,10 @@ async fn favourite_music(cx: Cx<ReceiveFavouriteMusicState>) -> Res { Ok(favourite_music) => { cx.answer(format!( "Fine. {}", - ExitState { data: cx.dialogue.clone(), favourite_music } + ExitState { + data: cx.dialogue.clone().unwrap(), + favourite_music + } )) .send() .await?; @@ -142,33 +146,24 @@ async fn favourite_music(cx: Cx<ReceiveFavouriteMusicState>) -> Res { } Err(_) => { cx.answer("Oh, please, enter from the keyboard!").send().await?; - next(Dialogue::ReceiveFavouriteMusic(cx.dialogue)) + next(Dialogue::ReceiveFavouriteMusic(cx.dialogue.unwrap())) } } } async fn handle_message(cx: Cx<Dialogue>) -> Res { - match cx { - DialogueDispatcherHandlerCx { - bot, - update, - dialogue: Dialogue::Start, - } => start(DialogueDispatcherHandlerCx::new(bot, update, ())).await, - DialogueDispatcherHandlerCx { - bot, - update, - dialogue: Dialogue::ReceiveFullName, - } => full_name(DialogueDispatcherHandlerCx::new(bot, update, ())).await, - DialogueDispatcherHandlerCx { - bot, - update, - dialogue: Dialogue::ReceiveAge(s), - } => age(DialogueDispatcherHandlerCx::new(bot, update, s)).await, - DialogueDispatcherHandlerCx { - bot, - update, - dialogue: Dialogue::ReceiveFavouriteMusic(s), - } => { + let DialogueDispatcherHandlerCx { bot, update, dialogue } = cx; + match dialogue.unwrap() { + Dialogue::Start => { + start(DialogueDispatcherHandlerCx::new(bot, update, ())).await + } + Dialogue::ReceiveFullName => { + full_name(DialogueDispatcherHandlerCx::new(bot, update, ())).await + } + Dialogue::ReceiveAge(s) => { + age(DialogueDispatcherHandlerCx::new(bot, update, s)).await + } + Dialogue::ReceiveFavouriteMusic(s) => { favourite_music(DialogueDispatcherHandlerCx::new(bot, update, s)) .await } @@ -191,8 +186,10 @@ async fn run() { let bot = Bot::from_env(); Dispatcher::new(bot) - .messages_handler(DialogueDispatcher::new(|cx| async move { - handle_message(cx).await.expect("Something wrong with the bot!") + .messages_handler(DialogueDispatcher::new(|cx| { + async move { + handle_message(cx).await.expect("Something wrong with the bot!") + } })) .dispatch() .await; diff --git a/examples/guess_a_number_bot/src/main.rs b/examples/guess_a_number_bot/src/main.rs index 4a712650..78727647 100644 --- a/examples/guess_a_number_bot/src/main.rs +++ b/examples/guess_a_number_bot/src/main.rs @@ -21,6 +21,7 @@ extern crate smart_default; use teloxide::prelude::*; +use std::convert::Infallible; use rand::{thread_rng, Rng}; // ============================================================================ @@ -38,7 +39,7 @@ enum Dialogue { // [Control a dialogue] // ============================================================================ -type Cx<State> = DialogueDispatcherHandlerCx<Message, State>; +type Cx<State> = DialogueDispatcherHandlerCx<Message, State, Infallible>; type Res = ResponseResult<DialogueStage<Dialogue>>; async fn start(cx: Cx<()>) -> Res { @@ -49,7 +50,7 @@ async fn start(cx: Cx<()>) -> Res { } async fn receive_attempt(cx: Cx<u8>) -> Res { - let secret = cx.dialogue; + let secret = cx.dialogue.unwrap(); match cx.update.text() { None => { @@ -77,24 +78,25 @@ async fn receive_attempt(cx: Cx<u8>) -> Res { } async fn handle_message( - cx: DialogueDispatcherHandlerCx<Message, Dialogue>, + cx: DialogueDispatcherHandlerCx<Message, Dialogue, Infallible>, ) -> Res { match cx { DialogueDispatcherHandlerCx { bot, update, - dialogue: Dialogue::Start, + dialogue: Ok(Dialogue::Start), } => start(DialogueDispatcherHandlerCx::new(bot, update, ())).await, DialogueDispatcherHandlerCx { bot, update, - dialogue: Dialogue::ReceiveAttempt(secret), + dialogue: Ok(Dialogue::ReceiveAttempt(secret)), } => { receive_attempt(DialogueDispatcherHandlerCx::new( bot, update, secret, )) .await } + _ => panic!("Failed to get dialogue info from storage") } } diff --git a/src/dispatching/dialogue/dialogue_dispatcher.rs b/src/dispatching/dialogue/dialogue_dispatcher.rs index 30b6bcfb..0181a564 100644 --- a/src/dispatching/dialogue/dialogue_dispatcher.rs +++ b/src/dispatching/dialogue/dialogue_dispatcher.rs @@ -5,13 +5,13 @@ use crate::dispatching::{ }, DispatcherHandler, DispatcherHandlerCx, }; -use std::{future::Future, pin::Pin}; +use std::{convert::Infallible, marker::PhantomData}; -use futures::StreamExt; +use futures::{future::BoxFuture, StreamExt}; use tokio::sync::mpsc; use lockfree::map::Map; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; /// A dispatcher of dialogues. /// @@ -23,9 +23,10 @@ use std::sync::Arc; /// /// [`Dispatcher`]: crate::dispatching::Dispatcher /// [`DispatcherHandler`]: crate::dispatching::DispatcherHandler -pub struct DialogueDispatcher<D, H, Upd> { - storage: Arc<dyn Storage<D> + Send + Sync + 'static>, +pub struct DialogueDispatcher<D, S, H, Upd> { + storage: Arc<S>, handler: Arc<H>, + _phantom: PhantomData<Mutex<D>>, /// A lock-free map to handle updates from the same chat sequentially, but /// concurrently from different chats. @@ -36,9 +37,9 @@ pub struct DialogueDispatcher<D, H, Upd> { senders: Arc<Map<i64, mpsc::UnboundedSender<DispatcherHandlerCx<Upd>>>>, } -impl<D, H, Upd> DialogueDispatcher<D, H, Upd> +impl<D, H, Upd> DialogueDispatcher<D, InMemStorage<D>, H, Upd> where - H: DialogueDispatcherHandler<Upd, D> + Send + Sync + 'static, + H: DialogueDispatcherHandler<Upd, D, Infallible> + Send + Sync + 'static, Upd: GetChatId + Send + 'static, D: Default + Send + 'static, { @@ -52,19 +53,27 @@ where storage: InMemStorage::new(), handler: Arc::new(handler), senders: Arc::new(Map::new()), + _phantom: PhantomData, } } +} +impl<D, S, H, Upd> DialogueDispatcher<D, S, H, Upd> +where + H: DialogueDispatcherHandler<Upd, D, S::Error> + Send + Sync + 'static, + Upd: GetChatId + Send + 'static, + D: Default + Send + 'static, + S: Storage<D> + Send + Sync + 'static, + S::Error: Send + 'static, +{ /// Creates a dispatcher with the specified `handler` and `storage`. #[must_use] - pub fn with_storage<Stg>(handler: H, storage: Arc<Stg>) -> Self - where - Stg: Storage<D> + Send + Sync + 'static, - { + pub fn with_storage(handler: H, storage: Arc<S>) -> Self { Self { storage, handler: Arc::new(handler), senders: Arc::new(Map::new()), + _phantom: PhantomData, } } @@ -87,7 +96,7 @@ where let dialogue = Arc::clone(&storage) .remove_dialogue(chat_id) .await - .unwrap_or_default(); + .map(Option::unwrap_or_default); match handler .handle(DialogueDispatcherHandlerCx { @@ -98,12 +107,15 @@ where .await { DialogueStage::Next(new_dialogue) => { - update_dialogue( - Arc::clone(&storage), - chat_id, - new_dialogue, - ) - .await; + if let Ok(Some(_)) = + storage.update_dialogue(chat_id, new_dialogue).await + { + panic!( + "Oops, you have an bug in your Storage: \ + update_dialogue returns Some after \ + remove_dialogue" + ); + } } DialogueStage::Exit => { // On the next .poll() call, the spawned future will @@ -122,31 +134,18 @@ where } } -async fn update_dialogue<D>( - storage: Arc<dyn Storage<D> + Send + Sync + 'static>, - chat_id: i64, - new_dialogue: D, -) where - D: 'static + Send, -{ - if storage.update_dialogue(chat_id, new_dialogue).await.is_some() { - panic!( - "Oops, you have an bug in your Storage: update_dialogue returns \ - Some after remove_dialogue" - ); - } -} - -impl<D, H, Upd> DispatcherHandler<Upd> for DialogueDispatcher<D, H, Upd> +impl<D, S, H, Upd> DispatcherHandler<Upd> for DialogueDispatcher<D, S, H, Upd> where - H: DialogueDispatcherHandler<Upd, D> + Send + Sync + 'static, + H: DialogueDispatcherHandler<Upd, D, S::Error> + Send + Sync + 'static, Upd: GetChatId + Send + 'static, D: Default + Send + 'static, + S: Storage<D> + Send + Sync + 'static, + S::Error: Send + 'static, { fn handle( self, updates: mpsc::UnboundedReceiver<DispatcherHandlerCx<Upd>>, - ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> + ) -> BoxFuture<'static, ()> where DispatcherHandlerCx<Upd>: 'static, { @@ -222,7 +221,7 @@ mod tests { } let dispatcher = DialogueDispatcher::new( - |cx: DialogueDispatcherHandlerCx<MyUpdate, ()>| async move { + |cx: DialogueDispatcherHandlerCx<MyUpdate, (), Infallible>| async move { delay_for(Duration::from_millis(300)).await; match cx.update { diff --git a/src/dispatching/dialogue/dialogue_dispatcher_handler.rs b/src/dispatching/dialogue/dialogue_dispatcher_handler.rs index 2d065838..69743621 100644 --- a/src/dispatching/dialogue/dialogue_dispatcher_handler.rs +++ b/src/dispatching/dialogue/dialogue_dispatcher_handler.rs @@ -8,27 +8,30 @@ use std::{future::Future, sync::Arc}; /// overview](crate::dispatching::dialogue). /// /// [`DialogueDispatcher`]: crate::dispatching::dialogue::DialogueDispatcher -pub trait DialogueDispatcherHandler<Upd, D> { +pub trait DialogueDispatcherHandler<Upd, D, E> { #[must_use] fn handle( self: Arc<Self>, - cx: DialogueDispatcherHandlerCx<Upd, D>, + cx: DialogueDispatcherHandlerCx<Upd, D, E>, ) -> BoxFuture<'static, DialogueStage<D>> where - DialogueDispatcherHandlerCx<Upd, D>: Send + 'static; + DialogueDispatcherHandlerCx<Upd, D, E>: Send + 'static; } -impl<Upd, D, F, Fut> DialogueDispatcherHandler<Upd, D> for F +impl<Upd, D, E, F, Fut> DialogueDispatcherHandler<Upd, D, E> for F where - F: Fn(DialogueDispatcherHandlerCx<Upd, D>) -> Fut + Send + Sync + 'static, + F: Fn(DialogueDispatcherHandlerCx<Upd, D, E>) -> Fut + + Send + + Sync + + 'static, Fut: Future<Output = DialogueStage<D>> + Send + 'static, { fn handle( self: Arc<Self>, - cx: DialogueDispatcherHandlerCx<Upd, D>, + cx: DialogueDispatcherHandlerCx<Upd, D, E>, ) -> BoxFuture<'static, Fut::Output> where - DialogueDispatcherHandlerCx<Upd, D>: Send + 'static, + DialogueDispatcherHandlerCx<Upd, D, E>: Send + 'static, { Box::pin(async move { self(cx).await }) } diff --git a/src/dispatching/dialogue/dialogue_dispatcher_handler_cx.rs b/src/dispatching/dialogue/dialogue_dispatcher_handler_cx.rs index d8024431..7c76db74 100644 --- a/src/dispatching/dialogue/dialogue_dispatcher_handler_cx.rs +++ b/src/dispatching/dialogue/dialogue_dispatcher_handler_cx.rs @@ -18,24 +18,24 @@ use std::sync::Arc; /// /// [`DialogueDispatcher`]: crate::dispatching::dialogue::DialogueDispatcher #[derive(Debug)] -pub struct DialogueDispatcherHandlerCx<Upd, D> { +pub struct DialogueDispatcherHandlerCx<Upd, D, E> { pub bot: Arc<Bot>, pub update: Upd, - pub dialogue: D, + pub dialogue: Result<D, E>, } -impl<Upd, D> DialogueDispatcherHandlerCx<Upd, D> { +impl<Upd, D, E> DialogueDispatcherHandlerCx<Upd, D, E> { /// Creates a new instance with the provided fields. pub fn new(bot: Arc<Bot>, update: Upd, dialogue: D) -> Self { - Self { bot, update, dialogue } + Self { bot, update, dialogue: Ok(dialogue) } } /// Creates a new instance by substituting a dialogue and preserving /// `self.bot` and `self.update`. - pub fn with_new_dialogue<Nd>( + pub fn with_new_dialogue<Nd, Ne>( self, - new_dialogue: Nd, - ) -> DialogueDispatcherHandlerCx<Upd, Nd> { + new_dialogue: Result<Nd, Ne>, + ) -> DialogueDispatcherHandlerCx<Upd, Nd, Ne> { DialogueDispatcherHandlerCx { bot: self.bot, update: self.update, @@ -44,7 +44,7 @@ impl<Upd, D> DialogueDispatcherHandlerCx<Upd, D> { } } -impl<Upd, D> GetChatId for DialogueDispatcherHandlerCx<Upd, D> +impl<Upd, D, E> GetChatId for DialogueDispatcherHandlerCx<Upd, D, E> where Upd: GetChatId, { @@ -53,7 +53,7 @@ where } } -impl<D> DialogueDispatcherHandlerCx<Message, D> { +impl<D, E> DialogueDispatcherHandlerCx<Message, D, E> { pub fn answer<T>(&self, text: T) -> SendMessage where T: Into<String>, diff --git a/src/dispatching/dialogue/storage/in_mem_storage.rs b/src/dispatching/dialogue/storage/in_mem_storage.rs index 5ac54c4b..55c8f345 100644 --- a/src/dispatching/dialogue/storage/in_mem_storage.rs +++ b/src/dispatching/dialogue/storage/in_mem_storage.rs @@ -23,24 +23,28 @@ impl<S> InMemStorage<S> { } impl<D> Storage<D> for InMemStorage<D> { + type Error = std::convert::Infallible; + fn remove_dialogue( self: Arc<Self>, chat_id: i64, - ) -> BoxFuture<'static, Option<D>> + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> where D: Send + 'static, { - Box::pin(async move { self.map.lock().await.remove(&chat_id) }) + Box::pin(async move { Ok(self.map.lock().await.remove(&chat_id)) }) } fn update_dialogue( self: Arc<Self>, chat_id: i64, dialogue: D, - ) -> BoxFuture<'static, Option<D>> + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> where D: Send + 'static, { - Box::pin(async move { self.map.lock().await.insert(chat_id, dialogue) }) + Box::pin( + async move { Ok(self.map.lock().await.insert(chat_id, dialogue)) }, + ) } } diff --git a/src/dispatching/dialogue/storage/mod.rs b/src/dispatching/dialogue/storage/mod.rs index 723cd42a..acbd9888 100644 --- a/src/dispatching/dialogue/storage/mod.rs +++ b/src/dispatching/dialogue/storage/mod.rs @@ -13,6 +13,8 @@ use std::sync::Arc; /// /// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage pub trait Storage<D> { + type Error; + /// Removes a dialogue with the specified `chat_id`. /// /// Returns `None` if there wasn't such a dialogue, `Some(dialogue)` if a @@ -20,7 +22,7 @@ pub trait Storage<D> { fn remove_dialogue( self: Arc<Self>, chat_id: i64, - ) -> BoxFuture<'static, Option<D>> + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> where D: Send + 'static; @@ -32,7 +34,7 @@ pub trait Storage<D> { self: Arc<Self>, chat_id: i64, dialogue: D, - ) -> BoxFuture<'static, Option<D>> + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> where D: Send + 'static; }