diff --git a/src/dispatching/dispatcher.rs b/src/dispatching/dispatcher.rs index 0d094fc0..97c63f1b 100644 --- a/src/dispatching/dispatcher.rs +++ b/src/dispatching/dispatcher.rs @@ -16,52 +16,55 @@ use futures::StreamExt; use std::fmt::Debug; pub struct BasicHandlerCtx<'a, Upd> { - bot: &'a Bot, - update: Upd, + pub bot: &'a Bot, + pub update: Upd, } pub struct Dispatcher<'a, Session1, Session2, H1, H2, HandlerE> { bot: &'a Bot, - handlers_error_handler: Box>, + handlers_error_handler: Box + 'a>, private_message_dp: Option>, private_edited_message_dp: Option>, - message_handler: Option, ()>>>, + message_handler: + Option, ()> + 'a>>, edited_message_handler: - Option, ()>>>, + Option, ()> + 'a>>, channel_post_handler: - Option, ()>>>, + Option, ()> + 'a>>, edited_channel_post_handler: - Option, ()>>>, + Option, ()> + 'a>>, inline_query_handler: - Option, ()>>>, - chosen_inline_result_handler: - Option, ()>>>, + Option, ()> + 'a>>, + chosen_inline_result_handler: Option< + Box, ()> + 'a>, + >, callback_query_handler: - Option, ()>>>, + Option, ()> + 'a>>, shipping_query_handler: - Option, ()>>>, - pre_checkout_query_handler: - Option, ()>>>, - poll_handler: Option, ()>>>, + Option, ()> + 'a>>, + pre_checkout_query_handler: Option< + Box, ()> + 'a>, + >, + poll_handler: Option, ()> + 'a>>, } impl<'a, Session1, Session2, H1, H2, HandlerE> Dispatcher<'a, Session1, Session2, H1, H2, HandlerE> where - Session1: Default, - Session2: Default, + Session1: Default + 'a, + Session2: Default + 'a, H1: Handler< - SessionHandlerCtx<'a, Message, Session1>, - SessionState, - >, + SessionHandlerCtx<'a, Message, Session1>, + SessionState, + > + 'a, H2: Handler< - SessionHandlerCtx<'a, Message, Session2>, - SessionState, - >, - HandlerE: Debug, + SessionHandlerCtx<'a, Message, Session2>, + SessionState, + > + 'a, + HandlerE: Debug + 'a, { pub fn new(bot: &'a Bot) -> Self { Self { @@ -82,6 +85,14 @@ where } } + pub fn handlers_error_handler(mut self, val: T) -> Self + where + T: Handler + 'a, + { + self.handlers_error_handler = Box::new(val); + self + } + pub fn private_message_dp( mut self, dp: SessionDispatcher<'a, Session1, H1>, @@ -90,37 +101,115 @@ where self } - async fn dispatch(&'a mut self) + pub fn private_edited_message_dp( + mut self, + dp: SessionDispatcher<'a, Session2, H2>, + ) -> Self { + self.private_edited_message_dp = Some(dp); + self + } + + pub fn message_handler(mut self, h: H) -> Self where - Session1: 'a, - Session2: 'a, - H1: 'a, - H2: 'a, - HandlerE: 'a, + H: Handler, ()> + 'a, { + self.message_handler = Some(Box::new(h)); + self + } + + pub fn edited_message_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.edited_message_handler = Some(Box::new(h)); + self + } + + pub fn channel_post_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.channel_post_handler = Some(Box::new(h)); + self + } + + pub fn edited_channel_post_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.edited_channel_post_handler = Some(Box::new(h)); + self + } + + pub fn inline_query_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.inline_query_handler = Some(Box::new(h)); + self + } + + pub fn chosen_inline_result_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.chosen_inline_result_handler = Some(Box::new(h)); + self + } + + pub fn callback_query_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.callback_query_handler = Some(Box::new(h)); + self + } + + pub fn shipping_query_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.shipping_query_handler = Some(Box::new(h)); + self + } + + pub fn pre_checkout_query_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.pre_checkout_query_handler = Some(Box::new(h)); + self + } + + pub fn poll_handler(mut self, h: H) -> Self + where + H: Handler, ()> + 'a, + { + self.poll_handler = Some(Box::new(h)); + self + } + + pub async fn dispatch(&'a mut self) { self.dispatch_with_listener( update_listeners::polling_default(self.bot), - error_handlers::Log, + &error_handlers::Log, ) .await; } - async fn dispatch_with_listener( - &'a mut self, + pub async fn dispatch_with_listener( + &'a self, update_listener: UListener, - update_listener_error_handler: Eh, + update_listener_error_handler: &'a Eh, ) where UListener: UpdateListener + 'a, Eh: Handler + 'a, - Session1: 'a, - Session2: 'a, - H1: 'a, - H2: 'a, - HandlerE: 'a, ListenerE: Debug, { + let update_listener = Box::pin(update_listener); + update_listener - .for_each_concurrent(None, move |update| async { + .for_each_concurrent(None, move |update| async move { let update = match update { Ok(update) => update, Err(error) => { @@ -133,7 +222,7 @@ where UpdateKind::Message(message) => match message.chat.kind { ChatKind::Private { .. } => { if let Some(private_message_dp) = - &mut self.private_message_dp + &self.private_message_dp { private_message_dp .dispatch(self.bot, message) @@ -141,8 +230,7 @@ where } } _ => { - if let Some(message_handler) = - &mut self.message_handler + if let Some(message_handler) = &self.message_handler { message_handler .handle(BasicHandlerCtx { @@ -158,7 +246,7 @@ where match message.chat.kind { ChatKind::Private { .. } => { if let Some(private_edited_message_dp) = - &mut self.private_edited_message_dp + &self.private_edited_message_dp { private_edited_message_dp .dispatch(self.bot, message) @@ -167,7 +255,7 @@ where } _ => { if let Some(edited_message_handler) = - &mut self.edited_message_handler + &self.edited_message_handler { edited_message_handler .handle(BasicHandlerCtx { @@ -181,7 +269,7 @@ where } UpdateKind::ChannelPost(post) => { if let Some(channel_post_handler) = - &mut self.channel_post_handler + &self.channel_post_handler { channel_post_handler .handle(BasicHandlerCtx { @@ -193,7 +281,7 @@ where } UpdateKind::EditedChannelPost(post) => { if let Some(edited_channel_post_handler) = - &mut self.edited_channel_post_handler + &self.edited_channel_post_handler { edited_channel_post_handler .handle(BasicHandlerCtx { @@ -205,7 +293,7 @@ where } UpdateKind::InlineQuery(query) => { if let Some(inline_query_handler) = - &mut self.inline_query_handler + &self.inline_query_handler { inline_query_handler .handle(BasicHandlerCtx { @@ -217,7 +305,7 @@ where } UpdateKind::ChosenInlineResult(result) => { if let Some(chosen_inline_result_handler) = - &mut self.chosen_inline_result_handler + &self.chosen_inline_result_handler { chosen_inline_result_handler .handle(BasicHandlerCtx { @@ -229,7 +317,7 @@ where } UpdateKind::CallbackQuery(query) => { if let Some(callback_query_handler) = - &mut self.callback_query_handler + &self.callback_query_handler { callback_query_handler .handle(BasicHandlerCtx { @@ -241,7 +329,7 @@ where } UpdateKind::ShippingQuery(query) => { if let Some(shipping_query_handler) = - &mut self.shipping_query_handler + &self.shipping_query_handler { shipping_query_handler .handle(BasicHandlerCtx { @@ -253,7 +341,7 @@ where } UpdateKind::PreCheckoutQuery(query) => { if let Some(pre_checkout_query_handler) = - &mut self.pre_checkout_query_handler + &self.pre_checkout_query_handler { pre_checkout_query_handler .handle(BasicHandlerCtx { @@ -264,7 +352,7 @@ where } } UpdateKind::Poll(poll) => { - if let Some(poll_handler) = &mut self.poll_handler { + if let Some(poll_handler) = &self.poll_handler { poll_handler .handle(BasicHandlerCtx { bot: self.bot, diff --git a/src/dispatching/session/mod.rs b/src/dispatching/session/mod.rs index 04f41a5c..ec47ab25 100644 --- a/src/dispatching/session/mod.rs +++ b/src/dispatching/session/mod.rs @@ -84,7 +84,7 @@ where } /// Dispatches a single `message` from a private chat. - pub async fn dispatch(&'a mut self, bot: &'a Bot, update: Upd) + pub async fn dispatch(&'a self, bot: &'a Bot, update: Upd) where H: Handler, SessionState>, Upd: GetChatId, diff --git a/src/dispatching/session/storage/in_mem_storage.rs b/src/dispatching/session/storage/in_mem_storage.rs index 140ee133..3203093c 100644 --- a/src/dispatching/session/storage/in_mem_storage.rs +++ b/src/dispatching/session/storage/in_mem_storage.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use super::Storage; use std::collections::HashMap; +use tokio::sync::Mutex; /// A memory storage based on a hash map. Stores all the sessions directly in /// RAM. @@ -10,23 +11,23 @@ use std::collections::HashMap; /// All the sessions will be lost after you restart your bot. If you need to /// store them somewhere on a drive, you need to implement a storage /// communicating with a DB. -#[derive(Clone, Debug, Eq, PartialEq, Default)] +#[derive(Debug, Default)] pub struct InMemStorage { - map: HashMap, + map: Mutex>, } #[async_trait(?Send)] #[async_trait] impl Storage for InMemStorage { - async fn remove_session(&mut self, chat_id: i64) -> Option { - self.map.remove(&chat_id) + async fn remove_session(&self, chat_id: i64) -> Option { + self.map.lock().await.remove(&chat_id) } async fn update_session( - &mut self, + &self, chat_id: i64, state: Session, ) -> Option { - self.map.insert(chat_id, state) + self.map.lock().await.insert(chat_id, state) } } diff --git a/src/dispatching/session/storage/mod.rs b/src/dispatching/session/storage/mod.rs index e3e12bbf..5bff11f6 100644 --- a/src/dispatching/session/storage/mod.rs +++ b/src/dispatching/session/storage/mod.rs @@ -18,14 +18,14 @@ pub trait Storage { /// /// Returns `None` if there wasn't such a session, `Some(session)` if a /// `session` was deleted. - async fn remove_session(&mut self, chat_id: i64) -> Option; + async fn remove_session(&self, chat_id: i64) -> Option; /// Updates a session with the specified `chat_id`. /// /// Returns `None` if there wasn't such a session, `Some(session)` if a /// `session` was updated. async fn update_session( - &mut self, + &self, chat_id: i64, session: Session, ) -> Option;