diff --git a/examples/sqlite_remember_bot/Cargo.toml b/examples/sqlite_remember_bot/Cargo.toml index d224d528..9815aaa6 100644 --- a/examples/sqlite_remember_bot/Cargo.toml +++ b/examples/sqlite_remember_bot/Cargo.toml @@ -7,6 +7,7 @@ edition = "2018" [dependencies] # You can also choose "cbor-serializer" or built-in JSON serializer teloxide = { path = "../../", features = ["sqlite-storage", "bincode-serializer", "redis-storage", "macros", "auto-send"] } +dptree = { path = "../../../chakka" } log = "0.4.8" pretty_env_logger = "0.4.0" @@ -16,4 +17,3 @@ serde = "1.0.104" futures = "0.3.5" thiserror = "1.0.15" -derive_more = "0.99.9" diff --git a/examples/sqlite_remember_bot/src/main.rs b/examples/sqlite_remember_bot/src/main.rs index c8980765..23f33980 100644 --- a/examples/sqlite_remember_bot/src/main.rs +++ b/examples/sqlite_remember_bot/src/main.rs @@ -1,19 +1,15 @@ -#[macro_use] -extern crate derive_more; - -mod states; -mod transitions; - -use states::*; - use teloxide::{ - dispatching::dialogue::{serializer::Json, SqliteStorage, Storage}, + dispatching2::dialogue::{serializer::Json, SqliteStorage, Storage}, prelude::*, RequestError, }; use thiserror::Error; +use std::sync::Arc; -type StorageError = as Storage>::Error; +type Store = SqliteStorage; +// FIXME: naming +type MyDialogue = Dialogue; +type StorageError = as Storage>::Error; #[derive(Debug, Error)] enum Error { @@ -23,33 +19,69 @@ enum Error { StorageError(#[from] StorageError), } -type In = DialogueWithCx, Message, Dialogue, StorageError>; +#[derive(serde::Serialize, serde::Deserialize)] +pub enum BotDialogue { + Start, + HaveNumber(i32), +} + +impl Default for BotDialogue { + fn default() -> Self { + Self::Start + } +} async fn handle_message( - cx: UpdateWithCx, Message>, - dialogue: Dialogue, -) -> TransitionOut { - match cx.update.text().map(ToOwned::to_owned) { + bot: Arc>, + mes: Arc, + dialogue: Arc, +) -> Result<(), Error> { + match mes.text() { None => { - cx.answer("Send me a text message.").await?; - next(dialogue) + bot.send_message(mes.chat.id, "Send me a text message.").await?; } - Some(ans) => dialogue.react(cx, ans).await, + Some(ans) => { + let state = dialogue.current_state_or_default().await?; + match state { + BotDialogue::Start => { + if let Ok(number) = ans.parse() { + dialogue.next(BotDialogue::HaveNumber(number)).await?; + bot.send_message(mes.chat.id, format!("Remembered number {}. Now use /get or /reset", number)).await?; + } else { + bot.send_message(mes.chat.id, "Please, send me a number").await?; + } + } + BotDialogue::HaveNumber(num) => { + if ans.starts_with("/get") { + bot.send_message(mes.chat.id, format!("Here is your number: {}", num)).await?; + } else if ans.starts_with("/reset") { + dialogue.reset().await?; + bot.send_message(mes.chat.id, "Resetted number").await?; + } else { + bot.send_message(mes.chat.id, "Please, send /get or /reset").await?; + } + } + } + }, } + Ok(()) } #[tokio::main] async fn main() { - let bot = Bot::from_env().auto_send(); + let bot = Arc::new(Bot::from_env().auto_send()); + let storage = SqliteStorage::open("db.sqlite", Json).await.unwrap(); Dispatcher::new(bot) - .messages_handler(DialogueDispatcher::with_storage( - |DialogueWithCx { cx, dialogue }: In| async move { - let dialogue = dialogue.expect("std::convert::Infallible"); - handle_message(cx, dialogue).await.expect("Something wrong with the bot!") - }, - SqliteStorage::open("db.sqlite", Json).await.unwrap(), - )) + .dependencies({ + let mut map = dptree::di::DependencyMap::new(); + map.insert_arc(storage); + map + }) + .messages_handler(|h| { + h.add_dialogue::() + .branch(dptree::endpoint(handle_message)) + }) .dispatch() .await; } diff --git a/examples/sqlite_remember_bot/src/states.rs b/examples/sqlite_remember_bot/src/states.rs deleted file mode 100644 index 1c007b5a..00000000 --- a/examples/sqlite_remember_bot/src/states.rs +++ /dev/null @@ -1,23 +0,0 @@ -use teloxide::macros::Transition; - -use serde::{Deserialize, Serialize}; - -#[derive(Transition, From, Serialize, Deserialize)] -pub enum Dialogue { - Start(StartState), - HaveNumber(HaveNumberState), -} - -impl Default for Dialogue { - fn default() -> Self { - Self::Start(StartState) - } -} - -#[derive(Serialize, Deserialize)] -pub struct StartState; - -#[derive(Serialize, Deserialize)] -pub struct HaveNumberState { - pub number: i32, -} diff --git a/examples/sqlite_remember_bot/src/transitions.rs b/examples/sqlite_remember_bot/src/transitions.rs deleted file mode 100644 index 2606e203..00000000 --- a/examples/sqlite_remember_bot/src/transitions.rs +++ /dev/null @@ -1,39 +0,0 @@ -use teloxide::prelude::*; -use teloxide::macros::teloxide; - -use super::states::*; - -#[teloxide(subtransition)] -async fn start( - state: StartState, - cx: TransitionIn>, - ans: String, -) -> TransitionOut { - if let Ok(number) = ans.parse() { - cx.answer(format!("Remembered number {}. Now use /get or /reset", number)).await?; - next(HaveNumberState { number }) - } else { - cx.answer("Please, send me a number").await?; - next(state) - } -} - -#[teloxide(subtransition)] -async fn have_number( - state: HaveNumberState, - cx: TransitionIn>, - ans: String, -) -> TransitionOut { - let num = state.number; - - if ans.starts_with("/get") { - cx.answer(format!("Here is your number: {}", num)).await?; - next(state) - } else if ans.starts_with("/reset") { - cx.answer("Resetted number").await?; - next(StartState) - } else { - cx.answer("Please, send /get or /reset").await?; - next(state) - } -} diff --git a/src/dispatching2/dialogue/dialogue_handler_ext.rs b/src/dispatching2/dialogue/dialogue_handler_ext.rs new file mode 100644 index 00000000..7e88de9d --- /dev/null +++ b/src/dispatching2/dialogue/dialogue_handler_ext.rs @@ -0,0 +1,29 @@ +use crate::dispatching2::dialogue::{get_chat_id::GetChatId, Dialogue, Storage}; +use dptree::{di::DependencyMap, Handler, Insert}; +use std::sync::Arc; + +pub trait DialogueHandlerExt { + fn add_dialogue(self) -> Self + where + S: Storage + Send + Sync + 'static, + D: Send + Sync + 'static, + Upd: GetChatId + Send + Sync + 'static; +} + +impl<'a, Output> DialogueHandlerExt for Handler<'a, DependencyMap, Output> +where + Output: Send + Sync + 'static, +{ + fn add_dialogue(self) -> Self + where + // FIXME: some of this requirements are useless. + S: Storage + Send + Sync + 'static, + D: Send + Sync + 'static, + Upd: GetChatId + Send + Sync + 'static, + { + self.chain(dptree::map(|storage: Arc, upd: Arc| async move { + let chat_id = upd.chat_id()?; + Dialogue::new(storage, chat_id).ok() + })) + } +} diff --git a/src/dispatching2/dialogue/get_chat_id.rs b/src/dispatching2/dialogue/get_chat_id.rs new file mode 100644 index 00000000..2dabcabd --- /dev/null +++ b/src/dispatching2/dialogue/get_chat_id.rs @@ -0,0 +1,20 @@ +use crate::types::CallbackQuery; +use teloxide_core::types::Message; + +/// Something that maybe has a chat ID. +pub trait GetChatId { + #[must_use] + fn chat_id(&self) -> Option; +} + +impl GetChatId for Message { + fn chat_id(&self) -> Option { + Some(self.chat.id) + } +} + +impl GetChatId for CallbackQuery { + fn chat_id(&self) -> Option { + self.message.as_ref().map(|mes| mes.chat.id) + } +} diff --git a/src/dispatching2/dialogue/mod.rs b/src/dispatching2/dialogue/mod.rs new file mode 100644 index 00000000..d782b2ce --- /dev/null +++ b/src/dispatching2/dialogue/mod.rs @@ -0,0 +1,85 @@ +#[cfg(feature = "redis-storage")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))] +pub use storage::{RedisStorage, RedisStorageError}; + +#[cfg(feature = "sqlite-storage")] +pub use storage::{SqliteStorage, SqliteStorageError}; + +pub use storage::{serializer, InMemStorage, InMemStorageError, Serializer, Storage, TraceStorage}; + +pub use dialogue_handler_ext::DialogueHandlerExt; + +use std::{future::Future, marker::PhantomData, sync::Arc}; + +mod dialogue_handler_ext; +mod get_chat_id; +mod storage; + +#[derive(Debug)] +pub struct Dialogue { + // Maybe it's better to use Box> here but it's require + // us to introduce `Err` generic parameter. + storage: Arc, + chat_id: i64, + _phantom: PhantomData, +} + +impl Dialogue +where + D: Send + 'static, + S: Storage, +{ + pub fn new(storage: Arc, chat_id: i64) -> Result { + Ok(Self { storage, chat_id, _phantom: PhantomData }) + } + + // TODO: Cache this. + pub async fn current_state(&self) -> Result, S::Error> { + self.storage.clone().get_dialogue(self.chat_id).await + } + + pub async fn current_state_or_default(&self) -> Result + where + D: Default, + { + match self.storage.clone().get_dialogue(self.chat_id).await? { + Some(d) => Ok(d), + None => { + self.storage.clone().update_dialogue(self.chat_id, D::default()).await?; + Ok(D::default()) + } + } + } + + pub async fn next(&self, state: State) -> Result<(), S::Error> + where + D: From, + { + let new_dialogue = state.into(); + self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?; + Ok(()) + } + + pub async fn with(&self, f: F) -> Result<(), S::Error> + where + F: FnOnce(Option) -> Fut, + Fut: Future, + D: From, + { + let current_dialogue = self.current_state().await?; + let new_dialogue = f(current_dialogue).await.into(); + self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?; + Ok(()) + } + + pub async fn reset(&self) -> Result<(), S::Error> + where + D: Default, + { + self.next(D::default()).await + } + + pub async fn exit(&self) -> Result<(), S::Error> { + self.storage.clone().remove_dialogue(self.chat_id).await + } +} diff --git a/src/dispatching2/dialogue/storage/in_mem_storage.rs b/src/dispatching2/dialogue/storage/in_mem_storage.rs new file mode 100644 index 00000000..d26a21eb --- /dev/null +++ b/src/dispatching2/dialogue/storage/in_mem_storage.rs @@ -0,0 +1,73 @@ +use super::Storage; +use futures::future::BoxFuture; +use std::{collections::HashMap, sync::Arc}; +use thiserror::Error; +use tokio::sync::Mutex; + +/// An error returned from [`InMemStorage`]. +#[derive(Debug, Error)] +pub enum InMemStorageError { + /// Returned from [`InMemStorage::remove_dialogue`]. + #[error("row not found")] + DialogueNotFound, +} + +/// A dialogue storage based on [`std::collections::HashMap`]. +/// +/// ## Note +/// All your dialogues will be lost after you restart your bot. If you need to +/// store them somewhere on a drive, you should use e.g. +/// [`super::SqliteStorage`] or implement your own. +#[derive(Debug)] +pub struct InMemStorage { + map: Mutex>, +} + +impl InMemStorage { + #[must_use] + pub fn new() -> Arc { + Arc::new(Self { map: Mutex::new(HashMap::new()) }) + } +} + +impl Storage for InMemStorage +where + D: Clone, + D: Send + 'static, +{ + type Error = InMemStorageError; + + fn remove_dialogue(self: Arc, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static, + { + Box::pin(async move { + self.map + .lock() + .await + .remove(&chat_id) + .map_or(Err(InMemStorageError::DialogueNotFound), |_| Ok(())) + }) + } + + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static, + { + Box::pin(async move { + self.map.lock().await.insert(chat_id, dialogue); + Ok(()) + }) + } + + fn get_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { Ok(self.map.lock().await.get(&chat_id).map(ToOwned::to_owned)) }) + } +} diff --git a/src/dispatching2/dialogue/storage/mod.rs b/src/dispatching2/dialogue/storage/mod.rs new file mode 100644 index 00000000..dbf4c25a --- /dev/null +++ b/src/dispatching2/dialogue/storage/mod.rs @@ -0,0 +1,76 @@ +pub mod serializer; + +mod in_mem_storage; +mod trace_storage; + +#[cfg(feature = "redis-storage")] +mod redis_storage; + +#[cfg(feature = "sqlite-storage")] +mod sqlite_storage; + +use futures::future::BoxFuture; + +pub use self::{ + in_mem_storage::{InMemStorage, InMemStorageError}, + trace_storage::TraceStorage, +}; + +#[cfg(feature = "redis-storage")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))] +pub use redis_storage::{RedisStorage, RedisStorageError}; +pub use serializer::Serializer; +use std::sync::Arc; + +#[cfg(feature = "sqlite-storage")] +pub use sqlite_storage::{SqliteStorage, SqliteStorageError}; + +/// A storage of dialogues. +/// +/// You can implement this trait for a structure that communicates with a DB and +/// be sure that after you restart your bot, all the dialogues won't be lost. +/// +/// `Storage` is used only to store dialogue states, i.e. it can't be used as a +/// generic database. +/// +/// Currently we support the following storages out of the box: +/// +/// - [`InMemStorage`] -- a storage based on [`std::collections::HashMap`]. +/// - [`RedisStorage`] -- a Redis-based storage. +/// - [`SqliteStorage`] -- an SQLite-based persistent storage. +/// +/// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage +/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage +/// [`SqliteStorage`]: crate::dispatching::dialogue::SqliteStorage +pub trait Storage { + type Error; + + /// Removes a dialogue indexed by `chat_id`. + /// + /// If the dialogue indexed by `chat_id` does not exist, this function + /// results in an error. + #[must_use = "Futures are lazy and do nothing unless polled with .await"] + fn remove_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static; + + /// Updates a dialogue indexed by `chat_id` with `dialogue`. + #[must_use = "Futures are lazy and do nothing unless polled with .await"] + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static; + + /// Returns the dialogue indexed by `chat_id`. + #[must_use = "Futures are lazy and do nothing unless polled with .await"] + fn get_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>>; +} diff --git a/src/dispatching2/dialogue/storage/redis_storage.rs b/src/dispatching2/dialogue/storage/redis_storage.rs new file mode 100644 index 00000000..f3889acc --- /dev/null +++ b/src/dispatching2/dialogue/storage/redis_storage.rs @@ -0,0 +1,110 @@ +use super::{serializer::Serializer, Storage}; +use futures::future::BoxFuture; +use redis::{AsyncCommands, IntoConnectionInfo}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + convert::Infallible, + fmt::{Debug, Display}, + ops::DerefMut, + sync::Arc, +}; +use thiserror::Error; +use tokio::sync::Mutex; + +/// An error returned from [`RedisStorage`]. +#[derive(Debug, Error)] +pub enum RedisStorageError +where + SE: Debug + Display, +{ + #[error("parsing/serializing error: {0}")] + SerdeError(SE), + + #[error("error from Redis: {0}")] + RedisError(#[from] redis::RedisError), + + /// Returned from [`RedisStorage::remove_dialogue`]. + #[error("row not found")] + DialogueNotFound, +} + +/// A dialogue storage based on [Redis](https://redis.io/). +pub struct RedisStorage { + conn: Mutex, + serializer: S, +} + +impl RedisStorage { + pub async fn open( + url: impl IntoConnectionInfo, + serializer: S, + ) -> Result, RedisStorageError> { + Ok(Arc::new(Self { + conn: Mutex::new(redis::Client::open(url)?.get_async_connection().await?), + serializer, + })) + } +} + +impl Storage for RedisStorage +where + S: Send + Sync + Serializer + 'static, + D: Send + Serialize + DeserializeOwned + 'static, + >::Error: Debug + Display, +{ + type Error = RedisStorageError<>::Error>; + + fn remove_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::pin(async move { + let deleted_rows_count = redis::pipe() + .atomic() + .del(chat_id) + .query_async::<_, redis::Value>(self.conn.lock().await.deref_mut()) + .await?; + + if let redis::Value::Bulk(values) = deleted_rows_count { + // False positive + #[allow(clippy::collapsible_match)] + if let redis::Value::Int(deleted_rows_count) = values[0] { + match deleted_rows_count { + 0 => return Err(RedisStorageError::DialogueNotFound), + _ => return Ok(()), + } + } + } + + unreachable!("Must return redis::Value::Bulk(redis::Value::Int(_))"); + }) + } + + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::pin(async move { + let dialogue = + self.serializer.serialize(&dialogue).map_err(RedisStorageError::SerdeError)?; + self.conn.lock().await.set::<_, Vec, _>(chat_id, dialogue).await?; + Ok(()) + }) + } + + fn get_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + self.conn + .lock() + .await + .get::<_, Option>>(chat_id) + .await? + .map(|d| self.serializer.deserialize(&d).map_err(RedisStorageError::SerdeError)) + .transpose() + }) + } +} diff --git a/src/dispatching2/dialogue/storage/serializer.rs b/src/dispatching2/dialogue/storage/serializer.rs new file mode 100644 index 00000000..2fb30cbc --- /dev/null +++ b/src/dispatching2/dialogue/storage/serializer.rs @@ -0,0 +1,77 @@ +//! Various serializers for dialogue storages. + +use serde::{de::DeserializeOwned, ser::Serialize}; + +/// A serializer for memory storages. +pub trait Serializer { + type Error; + + fn serialize(&self, val: &D) -> Result, Self::Error>; + fn deserialize(&self, data: &[u8]) -> Result; +} + +/// The JSON serializer for memory storages. +pub struct Json; + +impl Serializer for Json +where + D: Serialize + DeserializeOwned, +{ + type Error = serde_json::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + serde_json::to_vec(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + serde_json::from_slice(data) + } +} + +/// The [CBOR] serializer for memory storages. +/// +/// [CBOR]: https://en.wikipedia.org/wiki/CBOR +#[cfg(feature = "cbor-serializer")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))] +pub struct Cbor; + +#[cfg(feature = "cbor-serializer")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))] +impl Serializer for Cbor +where + D: Serialize + DeserializeOwned, +{ + type Error = serde_cbor::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + serde_cbor::to_vec(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + serde_cbor::from_slice(data) + } +} + +/// The [Bincode] serializer for memory storages. +/// +/// [Bincode]: https://github.com/servo/bincode +#[cfg(feature = "bincode-serializer")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))] +pub struct Bincode; + +#[cfg(feature = "bincode-serializer")] +#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))] +impl Serializer for Bincode +where + D: Serialize + DeserializeOwned, +{ + type Error = bincode::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + bincode::serialize(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + bincode::deserialize(data) + } +} diff --git a/src/dispatching2/dialogue/storage/sqlite_storage.rs b/src/dispatching2/dialogue/storage/sqlite_storage.rs new file mode 100644 index 00000000..a562b5e5 --- /dev/null +++ b/src/dispatching2/dialogue/storage/sqlite_storage.rs @@ -0,0 +1,141 @@ +use super::{serializer::Serializer, Storage}; +use futures::future::BoxFuture; +use serde::{de::DeserializeOwned, Serialize}; +use sqlx::{sqlite::SqlitePool, Executor}; +use std::{ + convert::Infallible, + fmt::{Debug, Display}, + str, + sync::Arc, +}; +use thiserror::Error; + +/// A persistent dialogue storage based on [SQLite](https://www.sqlite.org/). +pub struct SqliteStorage { + pool: SqlitePool, + serializer: S, +} + +/// An error returned from [`SqliteStorage`]. +#[derive(Debug, Error)] +pub enum SqliteStorageError +where + SE: Debug + Display, +{ + #[error("dialogue serialization error: {0}")] + SerdeError(SE), + + #[error("sqlite error: {0}")] + SqliteError(#[from] sqlx::Error), + + /// Returned from [`SqliteStorage::remove_dialogue`]. + #[error("row not found")] + DialogueNotFound, +} + +impl SqliteStorage { + pub async fn open( + path: &str, + serializer: S, + ) -> Result, SqliteStorageError> { + let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?; + let mut conn = pool.acquire().await?; + sqlx::query( + r#" +CREATE TABLE IF NOT EXISTS teloxide_dialogues ( + chat_id BIGINT PRIMARY KEY, + dialogue BLOB NOT NULL +); + "#, + ) + .execute(&mut conn) + .await?; + + Ok(Arc::new(Self { pool, serializer })) + } +} + +impl Storage for SqliteStorage +where + S: Send + Sync + Serializer + 'static, + D: Send + Serialize + DeserializeOwned + 'static, + >::Error: Debug + Display, +{ + type Error = SqliteStorageError<>::Error>; + + /// Returns [`sqlx::Error::RowNotFound`] if a dialogue does not exist. + fn remove_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::pin(async move { + let deleted_rows_count = + sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?") + .bind(chat_id) + .execute(&self.pool) + .await? + .rows_affected(); + + if deleted_rows_count == 0 { + return Err(SqliteStorageError::DialogueNotFound); + } + + Ok(()) + }) + } + + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result<(), Self::Error>> { + Box::pin(async move { + let d = self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; + + self.pool + .acquire() + .await? + .execute( + sqlx::query( + r#" + INSERT INTO teloxide_dialogues VALUES (?, ?) + ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue + "#, + ) + .bind(chat_id) + .bind(d), + ) + .await?; + Ok(()) + }) + } + + fn get_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + get_dialogue(&self.pool, chat_id) + .await? + .map(|d| self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)) + .transpose() + }) + } +} + +async fn get_dialogue(pool: &SqlitePool, chat_id: i64) -> Result>, sqlx::Error> { + #[derive(sqlx::FromRow)] + struct DialogueDbRow { + dialogue: Vec, + } + + let bytes = sqlx::query_as::<_, DialogueDbRow>( + "SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?", + ) + .bind(chat_id) + .fetch_optional(pool) + .await? + .map(|r| r.dialogue); + + Ok(bytes) +} diff --git a/src/dispatching2/dialogue/storage/trace_storage.rs b/src/dispatching2/dialogue/storage/trace_storage.rs new file mode 100644 index 00000000..f4e22d8e --- /dev/null +++ b/src/dispatching2/dialogue/storage/trace_storage.rs @@ -0,0 +1,68 @@ +use std::{ + fmt::Debug, + marker::{Send, Sync}, + sync::Arc, +}; + +use futures::future::BoxFuture; + +use crate::dispatching::dialogue::Storage; + +/// A dialogue storage wrapper which logs all actions performed on an underlying +/// storage. +/// +/// Reports about any dialogue action via [`log::Level::Trace`]. +pub struct TraceStorage { + inner: Arc, +} + +impl TraceStorage { + #[must_use] + pub fn new(inner: Arc) -> Arc { + Arc::new(Self { inner }) + } + + pub fn into_inner(self) -> Arc { + self.inner + } +} + +impl Storage for TraceStorage +where + D: Debug, + S: Storage + Send + Sync + 'static, +{ + type Error = >::Error; + + fn remove_dialogue(self: Arc, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static, + { + log::trace!("Removing dialogue #{}", chat_id); + >::remove_dialogue(self.inner.clone(), chat_id) + } + + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result<(), Self::Error>> + where + D: Send + 'static, + { + Box::pin(async move { + let to = format!("{:#?}", dialogue); + >::update_dialogue(self.inner.clone(), chat_id, dialogue).await?; + log::trace!("Updated a dialogue #{}: {:#?}", chat_id, to); + Ok(()) + }) + } + + fn get_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + log::trace!("Requested a dialogue #{}", chat_id); + >::get_dialogue(self.inner.clone(), chat_id) + } +} diff --git a/src/dispatching2/mod.rs b/src/dispatching2/mod.rs index be9662d6..57488066 100644 --- a/src/dispatching2/mod.rs +++ b/src/dispatching2/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod repls; +pub mod dialogue; mod dispatcher; pub use dispatcher::Dispatcher; diff --git a/src/prelude.rs b/src/prelude.rs index a49ee236..a011cfbf 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -15,7 +15,10 @@ pub use crate::dispatching::{ }; #[cfg(not(feature = "old_dispatching"))] -pub use crate::dispatching2::Dispatcher; +pub use crate::dispatching2::{ + dialogue::{Dialogue, DialogueHandlerExt as _}, + Dispatcher, +}; #[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "macros")))] #[cfg(feature = "macros")]