diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ee91e26..9d67b41f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,3 +1,4 @@ + on: push: branches: [ master ] @@ -35,6 +36,12 @@ jobs: profile: minimal toolchain: stable override: true + - name: Setup redis + run: | + sudo apt install redis-server + redis-server --port 7777 > /dev/null & + redis-server --port 7778 > /dev/null & + redis-server --port 7779 > /dev/null & - name: Cargo test run: cargo test --all-features build-example: diff --git a/Cargo.toml b/Cargo.toml index e2fcd313..1a975c17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,10 @@ authors = [ [badges] maintenance = { status = "actively-developed" } -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +redis-storage = ["redis"] +cbor-serializer = ["serde_cbor"] +bincode-serializer = ["bincode"] [dependencies] serde_json = "1.0.44" @@ -45,6 +48,10 @@ futures = "0.3.1" pin-project = "0.4.6" serde_with_macros = "1.0.1" +redis = { version = "0.15.1", optional = true } +serde_cbor = { version = "0.11.1", optional = true } +bincode = { version = "1.2.1", optional = true } + teloxide-macros = "0.3.1" [dev-dependencies] diff --git a/examples/dialogue_bot/src/main.rs b/examples/dialogue_bot/src/main.rs index e0447004..f9b9683d 100644 --- a/examples/dialogue_bot/src/main.rs +++ b/examples/dialogue_bot/src/main.rs @@ -45,9 +45,9 @@ async fn run() { Dispatcher::new(bot) .messages_handler(DialogueDispatcher::new( - |cx: DialogueWithCx| async move { + |input: TransitionIn| async move { // Unwrap without panic because of std::convert::Infallible. - dispatch(cx.cx, cx.dialogue.unwrap()) + dispatch(input.cx, input.dialogue.unwrap()) .await .expect("Something wrong with the bot!") }, diff --git a/examples/redis_remember_bot/Cargo.toml b/examples/redis_remember_bot/Cargo.toml new file mode 100644 index 00000000..6a91b292 --- /dev/null +++ b/examples/redis_remember_bot/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "redis_remember_bot" +version = "0.1.0" +authors = ["Maximilian Siling "] +edition = "2018" + +[dependencies] +tokio = "0.2.9" + +# You can also choose "cbor-serializer" or built-in JSON serializer +teloxide = { path = "../../", features = ["redis-storage", "bincode-serializer"] } +serde = "1.0.104" + +thiserror = "1.0.15" +smart-default = "0.6.0" +derive_more = "0.99.9" \ No newline at end of file diff --git a/examples/redis_remember_bot/src/main.rs b/examples/redis_remember_bot/src/main.rs new file mode 100644 index 00000000..bffaa8aa --- /dev/null +++ b/examples/redis_remember_bot/src/main.rs @@ -0,0 +1,66 @@ +#[macro_use] +extern crate smart_default; +#[macro_use] +extern crate derive_more; + +mod states; +mod transitions; + +use states::*; +use transitions::*; + +use teloxide::{ + dispatching::dialogue::{serializer::Bincode, RedisStorage, Storage}, + prelude::*, +}; +use thiserror::Error; + +type StorageError = as Storage>::Error; + +#[derive(Debug, Error)] +enum Error { + #[error("error from Telegram: {0}")] + TelegramError(#[from] RequestError), + #[error("error from storage: {0}")] + StorageError(#[from] StorageError), +} + +type In = TransitionIn; + +async fn handle_message(input: In) -> Out { + let (cx, dialogue) = input.unpack(); + + match cx.update.text_owned() { + Some(text) => dispatch(cx, dialogue, &text).await, + None => { + cx.answer_str("Please, send me a text message").await?; + next(StartState) + } + } +} + +#[tokio::main] +async fn main() { + run().await; +} + +async fn run() { + let bot = Bot::from_env(); + Dispatcher::new(bot) + .messages_handler(DialogueDispatcher::with_storage( + |cx| async move { + handle_message(cx) + .await + .expect("Something is wrong with the bot!") + }, + // You can also choose serializer::JSON or serializer::CBOR + // All serializers but JSON require enabling feature + // "serializer-", e. g. "serializer-cbor" + // or "serializer-bincode" + RedisStorage::open("redis://127.0.0.1:6379", Bincode) + .await + .unwrap(), + )) + .dispatch() + .await; +} diff --git a/examples/redis_remember_bot/src/states.rs b/examples/redis_remember_bot/src/states.rs new file mode 100644 index 00000000..142e823a --- /dev/null +++ b/examples/redis_remember_bot/src/states.rs @@ -0,0 +1,23 @@ +use teloxide::prelude::*; + +use serde::{Deserialize, Serialize}; + +#[derive(Default, Serialize, Deserialize)] +pub struct StartState; + +#[derive(Serialize, Deserialize)] +pub struct HaveNumberState { + rest: StartState, + pub number: i32, +} + +up!( + StartState + [number: i32] -> HaveNumberState, +); + +#[derive(SmartDefault, From, Serialize, Deserialize)] +pub enum Dialogue { + #[default] + Start(StartState), + HaveNumber(HaveNumberState), +} diff --git a/examples/redis_remember_bot/src/transitions.rs b/examples/redis_remember_bot/src/transitions.rs new file mode 100644 index 00000000..594471e1 --- /dev/null +++ b/examples/redis_remember_bot/src/transitions.rs @@ -0,0 +1,42 @@ +use teloxide::prelude::*; + +use super::states::*; + +pub type Cx = UpdateWithCx; +pub type Out = TransitionOut; + +async fn start(cx: Cx, state: StartState, text: &str) -> Out { + if let Ok(number) = text.parse() { + cx.answer_str(format!( + "Remembered number {}. Now use /get or /reset", + number + )) + .await?; + next(state.up(number)) + } else { + cx.answer_str("Please, send me a number").await?; + next(state) + } +} + +async fn have_number(cx: Cx, state: HaveNumberState, text: &str) -> Out { + let num = state.number; + + if text.starts_with("/get") { + cx.answer_str(format!("Here is your number: {}", num)).await?; + next(state) + } else if text.starts_with("/reset") { + cx.answer_str("Resetted number").await?; + next(StartState) + } else { + cx.answer_str("Please, send /get or /reset").await?; + next(state) + } +} + +pub async fn dispatch(cx: Cx, dialogue: Dialogue, text: &str) -> Out { + match dialogue { + Dialogue::Start(state) => start(cx, state, text).await, + Dialogue::HaveNumber(state) => have_number(cx, state, text).await, + } +} diff --git a/src/dispatching/dialogue/mod.rs b/src/dispatching/dialogue/mod.rs index 20efc346..09d753fd 100644 --- a/src/dispatching/dialogue/mod.rs +++ b/src/dispatching/dialogue/mod.rs @@ -55,7 +55,11 @@ pub use dialogue_dispatcher_handler::DialogueDispatcherHandler; pub use dialogue_stage::{exit, next, DialogueStage}; pub use dialogue_with_cx::DialogueWithCx; pub use get_chat_id::GetChatId; -pub use storage::{InMemStorage, Storage}; + +#[cfg(feature = "redis-storage")] +pub use storage::{RedisStorage, RedisStorageError}; + +pub use storage::{serializer, InMemStorage, Serializer, Storage}; /// Generates `.up(field)` methods for dialogue states. /// diff --git a/src/dispatching/dialogue/storage/mod.rs b/src/dispatching/dialogue/storage/mod.rs index acbd9888..ed5319c8 100644 --- a/src/dispatching/dialogue/storage/mod.rs +++ b/src/dispatching/dialogue/storage/mod.rs @@ -1,7 +1,16 @@ +pub mod serializer; + mod in_mem_storage; +#[cfg(feature = "redis-storage")] +mod redis_storage; + use futures::future::BoxFuture; + pub use in_mem_storage::InMemStorage; +#[cfg(feature = "redis-storage")] +pub use redis_storage::{RedisStorage, RedisStorageError}; +pub use serializer::Serializer; use std::sync::Arc; /// A storage of dialogues. diff --git a/src/dispatching/dialogue/storage/redis_storage.rs b/src/dispatching/dialogue/storage/redis_storage.rs new file mode 100644 index 00000000..37c5fd07 --- /dev/null +++ b/src/dispatching/dialogue/storage/redis_storage.rs @@ -0,0 +1,112 @@ +use super::{serializer::Serializer, Storage}; +use futures::future::BoxFuture; +use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo}; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + convert::Infallible, + fmt::{Debug, Display}, + ops::DerefMut, + sync::Arc, +}; +use thiserror::Error; +use tokio::sync::Mutex; + +/// An error returned from [`RedisStorage`]. +/// +/// [`RedisStorage`]: struct.RedisStorage.html +#[derive(Debug, Error)] +pub enum RedisStorageError +where + SE: Debug + Display, +{ + #[error("parsing/serializing error: {0}")] + SerdeError(SE), + #[error("error from Redis: {0}")] + RedisError(#[from] redis::RedisError), +} + +/// A memory storage based on [Redis](https://redis.io/). +pub struct RedisStorage { + conn: Mutex, + serializer: S, +} + +impl RedisStorage { + pub async fn open( + url: impl IntoConnectionInfo, + serializer: S, + ) -> Result, RedisStorageError> { + Ok(Arc::new(Self { + conn: Mutex::new( + redis::Client::open(url)?.get_async_connection().await?, + ), + serializer, + })) + } +} + +impl Storage for RedisStorage +where + S: Send + Sync + Serializer + 'static, + D: Send + Serialize + DeserializeOwned + 'static, + >::Error: Debug + Display, +{ + type Error = RedisStorageError<>::Error>; + + // `.del().ignore()` is much more readable than `.del()\n.ignore()` + #[rustfmt::skip] + fn remove_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + let res = redis::pipe() + .atomic() + .get(chat_id) + .del(chat_id).ignore() + .query_async::<_, redis::Value>( + self.conn.lock().await.deref_mut(), + ) + .await?; + // We're expecting `.pipe()` to return us an exactly one result in + // bulk, so all other branches should be unreachable + match res { + redis::Value::Bulk(bulk) if bulk.len() == 1 => { + Ok(Option::>::from_redis_value(&bulk[0])? + .map(|v| { + self.serializer + .deserialize(&v) + .map_err(RedisStorageError::SerdeError) + }) + .transpose()?) + } + _ => unreachable!(), + } + }) + } + + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + let dialogue = self + .serializer + .serialize(&dialogue) + .map_err(RedisStorageError::SerdeError)?; + Ok(self + .conn + .lock() + .await + .getset::<_, Vec, Option>>(chat_id, dialogue) + .await? + .map(|d| { + self.serializer + .deserialize(&d) + .map_err(RedisStorageError::SerdeError) + }) + .transpose()?) + }) + } +} diff --git a/src/dispatching/dialogue/storage/serializer.rs b/src/dispatching/dialogue/storage/serializer.rs new file mode 100644 index 00000000..f31724a4 --- /dev/null +++ b/src/dispatching/dialogue/storage/serializer.rs @@ -0,0 +1,68 @@ +/// Various serializers for memory storages. +use serde::{de::DeserializeOwned, ser::Serialize}; + +/// A serializer for memory storages. +pub trait Serializer { + type Error; + + fn serialize(&self, val: &D) -> Result, Self::Error>; + fn deserialize(&self, data: &[u8]) -> Result; +} + +/// The JSON serializer for memory storages. +pub struct JSON; + +impl Serializer for JSON +where + D: Serialize + DeserializeOwned, +{ + type Error = serde_json::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + serde_json::to_vec(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + serde_json::from_slice(data) + } +} + +/// The CBOR serializer for memory storages. +#[cfg(feature = "cbor-serializer")] +pub struct CBOR; + +#[cfg(feature = "cbor-serializer")] +impl Serializer for CBOR +where + D: Serialize + DeserializeOwned, +{ + type Error = serde_cbor::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + serde_cbor::to_vec(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + serde_cbor::from_slice(data) + } +} + +/// The Bincode serializer for memory storages. +#[cfg(feature = "bincode-serializer")] +pub struct Bincode; + +#[cfg(feature = "bincode-serializer")] +impl Serializer for Bincode +where + D: Serialize + DeserializeOwned, +{ + type Error = bincode::Error; + + fn serialize(&self, val: &D) -> Result, Self::Error> { + bincode::serialize(val) + } + + fn deserialize(&self, data: &[u8]) -> Result { + bincode::deserialize(data) + } +} diff --git a/tests/redis.rs b/tests/redis.rs new file mode 100644 index 00000000..61dd6ad1 --- /dev/null +++ b/tests/redis.rs @@ -0,0 +1,82 @@ +use std::{ + fmt::{Debug, Display}, + future::Future, + sync::Arc, +}; +use teloxide::dispatching::dialogue::{ + serializer::{Bincode, CBOR, JSON}, + RedisStorage, Serializer, Storage, +}; + +#[tokio::test] +async fn test_redis_json() { + let storage = + RedisStorage::open("redis://127.0.0.1:7777", JSON).await.unwrap(); + test_redis(storage).await; +} + +#[tokio::test] +async fn test_redis_bincode() { + let storage = + RedisStorage::open("redis://127.0.0.1:7778", Bincode).await.unwrap(); + test_redis(storage).await; +} + +#[tokio::test] +async fn test_redis_cbor() { + let storage = + RedisStorage::open("redis://127.0.0.1:7779", CBOR).await.unwrap(); + test_redis(storage).await; +} + +type Dialogue = String; + +async fn test_redis(storage: Arc>) +where + S: Send + Sync + Serializer + 'static, + >::Error: Debug + Display, +{ + check_dialogue( + None, + Arc::clone(&storage).update_dialogue(1, "ABC".to_owned()), + ) + .await; + check_dialogue( + None, + Arc::clone(&storage).update_dialogue(11, "DEF".to_owned()), + ) + .await; + check_dialogue( + None, + Arc::clone(&storage).update_dialogue(256, "GHI".to_owned()), + ) + .await; + + // 1 - ABC, 11 - DEF, 256 - GHI + + check_dialogue( + "ABC", + Arc::clone(&storage).update_dialogue(1, "JKL".to_owned()), + ) + .await; + check_dialogue( + "GHI", + Arc::clone(&storage).update_dialogue(256, "MNO".to_owned()), + ) + .await; + + // 1 - GKL, 11 - DEF, 256 - MNO + + check_dialogue("JKL", Arc::clone(&storage).remove_dialogue(1)).await; + check_dialogue("DEF", Arc::clone(&storage).remove_dialogue(11)).await; + check_dialogue("MNO", Arc::clone(&storage).remove_dialogue(256)).await; +} + +async fn check_dialogue( + expected: impl Into>, + actual: impl Future, E>>, +) where + E: Debug, +{ + assert_eq!(expected.into().map(ToOwned::to_owned), actual.await.unwrap()) +}