mirror of
https://github.com/teloxide/teloxide.git
synced 2025-01-24 17:22:43 +01:00
added dialogues + updated sqlite_remember_bot example.
This commit is contained in:
parent
3f1d1360c6
commit
6959d1c928
15 changed files with 743 additions and 90 deletions
|
@ -7,6 +7,7 @@ edition = "2018"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# You can also choose "cbor-serializer" or built-in JSON serializer
|
# You can also choose "cbor-serializer" or built-in JSON serializer
|
||||||
teloxide = { path = "../../", features = ["sqlite-storage", "bincode-serializer", "redis-storage", "macros", "auto-send"] }
|
teloxide = { path = "../../", features = ["sqlite-storage", "bincode-serializer", "redis-storage", "macros", "auto-send"] }
|
||||||
|
dptree = { path = "../../../chakka" }
|
||||||
|
|
||||||
log = "0.4.8"
|
log = "0.4.8"
|
||||||
pretty_env_logger = "0.4.0"
|
pretty_env_logger = "0.4.0"
|
||||||
|
@ -16,4 +17,3 @@ serde = "1.0.104"
|
||||||
futures = "0.3.5"
|
futures = "0.3.5"
|
||||||
|
|
||||||
thiserror = "1.0.15"
|
thiserror = "1.0.15"
|
||||||
derive_more = "0.99.9"
|
|
||||||
|
|
|
@ -1,19 +1,15 @@
|
||||||
#[macro_use]
|
|
||||||
extern crate derive_more;
|
|
||||||
|
|
||||||
mod states;
|
|
||||||
mod transitions;
|
|
||||||
|
|
||||||
use states::*;
|
|
||||||
|
|
||||||
use teloxide::{
|
use teloxide::{
|
||||||
dispatching::dialogue::{serializer::Json, SqliteStorage, Storage},
|
dispatching2::dialogue::{serializer::Json, SqliteStorage, Storage},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
RequestError,
|
RequestError,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
type StorageError = <SqliteStorage<Json> as Storage<Dialogue>>::Error;
|
type Store = SqliteStorage<Json>;
|
||||||
|
// FIXME: naming
|
||||||
|
type MyDialogue = Dialogue<BotDialogue, Store>;
|
||||||
|
type StorageError = <SqliteStorage<Json> as Storage<BotDialogue>>::Error;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
enum Error {
|
enum Error {
|
||||||
|
@ -23,33 +19,69 @@ enum Error {
|
||||||
StorageError(#[from] StorageError),
|
StorageError(#[from] StorageError),
|
||||||
}
|
}
|
||||||
|
|
||||||
type In = DialogueWithCx<AutoSend<Bot>, Message, Dialogue, StorageError>;
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
pub enum BotDialogue {
|
||||||
|
Start,
|
||||||
|
HaveNumber(i32),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BotDialogue {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Start
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_message(
|
async fn handle_message(
|
||||||
cx: UpdateWithCx<AutoSend<Bot>, Message>,
|
bot: Arc<AutoSend<Bot>>,
|
||||||
dialogue: Dialogue,
|
mes: Arc<Message>,
|
||||||
) -> TransitionOut<Dialogue> {
|
dialogue: Arc<MyDialogue>,
|
||||||
match cx.update.text().map(ToOwned::to_owned) {
|
) -> Result<(), Error> {
|
||||||
|
match mes.text() {
|
||||||
None => {
|
None => {
|
||||||
cx.answer("Send me a text message.").await?;
|
bot.send_message(mes.chat.id, "Send me a text message.").await?;
|
||||||
next(dialogue)
|
|
||||||
}
|
}
|
||||||
Some(ans) => dialogue.react(cx, ans).await,
|
Some(ans) => {
|
||||||
|
let state = dialogue.current_state_or_default().await?;
|
||||||
|
match state {
|
||||||
|
BotDialogue::Start => {
|
||||||
|
if let Ok(number) = ans.parse() {
|
||||||
|
dialogue.next(BotDialogue::HaveNumber(number)).await?;
|
||||||
|
bot.send_message(mes.chat.id, format!("Remembered number {}. Now use /get or /reset", number)).await?;
|
||||||
|
} else {
|
||||||
|
bot.send_message(mes.chat.id, "Please, send me a number").await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BotDialogue::HaveNumber(num) => {
|
||||||
|
if ans.starts_with("/get") {
|
||||||
|
bot.send_message(mes.chat.id, format!("Here is your number: {}", num)).await?;
|
||||||
|
} else if ans.starts_with("/reset") {
|
||||||
|
dialogue.reset().await?;
|
||||||
|
bot.send_message(mes.chat.id, "Resetted number").await?;
|
||||||
|
} else {
|
||||||
|
bot.send_message(mes.chat.id, "Please, send /get or /reset").await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let bot = Bot::from_env().auto_send();
|
let bot = Arc::new(Bot::from_env().auto_send());
|
||||||
|
let storage = SqliteStorage::open("db.sqlite", Json).await.unwrap();
|
||||||
|
|
||||||
Dispatcher::new(bot)
|
Dispatcher::new(bot)
|
||||||
.messages_handler(DialogueDispatcher::with_storage(
|
.dependencies({
|
||||||
|DialogueWithCx { cx, dialogue }: In| async move {
|
let mut map = dptree::di::DependencyMap::new();
|
||||||
let dialogue = dialogue.expect("std::convert::Infallible");
|
map.insert_arc(storage);
|
||||||
handle_message(cx, dialogue).await.expect("Something wrong with the bot!")
|
map
|
||||||
},
|
})
|
||||||
SqliteStorage::open("db.sqlite", Json).await.unwrap(),
|
.messages_handler(|h| {
|
||||||
))
|
h.add_dialogue::<Message, Store, BotDialogue>()
|
||||||
|
.branch(dptree::endpoint(handle_message))
|
||||||
|
})
|
||||||
.dispatch()
|
.dispatch()
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
use teloxide::macros::Transition;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Transition, From, Serialize, Deserialize)]
|
|
||||||
pub enum Dialogue {
|
|
||||||
Start(StartState),
|
|
||||||
HaveNumber(HaveNumberState),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Dialogue {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Start(StartState)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct StartState;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct HaveNumberState {
|
|
||||||
pub number: i32,
|
|
||||||
}
|
|
|
@ -1,39 +0,0 @@
|
||||||
use teloxide::prelude::*;
|
|
||||||
use teloxide::macros::teloxide;
|
|
||||||
|
|
||||||
use super::states::*;
|
|
||||||
|
|
||||||
#[teloxide(subtransition)]
|
|
||||||
async fn start(
|
|
||||||
state: StartState,
|
|
||||||
cx: TransitionIn<AutoSend<Bot>>,
|
|
||||||
ans: String,
|
|
||||||
) -> TransitionOut<Dialogue> {
|
|
||||||
if let Ok(number) = ans.parse() {
|
|
||||||
cx.answer(format!("Remembered number {}. Now use /get or /reset", number)).await?;
|
|
||||||
next(HaveNumberState { number })
|
|
||||||
} else {
|
|
||||||
cx.answer("Please, send me a number").await?;
|
|
||||||
next(state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[teloxide(subtransition)]
|
|
||||||
async fn have_number(
|
|
||||||
state: HaveNumberState,
|
|
||||||
cx: TransitionIn<AutoSend<Bot>>,
|
|
||||||
ans: String,
|
|
||||||
) -> TransitionOut<Dialogue> {
|
|
||||||
let num = state.number;
|
|
||||||
|
|
||||||
if ans.starts_with("/get") {
|
|
||||||
cx.answer(format!("Here is your number: {}", num)).await?;
|
|
||||||
next(state)
|
|
||||||
} else if ans.starts_with("/reset") {
|
|
||||||
cx.answer("Resetted number").await?;
|
|
||||||
next(StartState)
|
|
||||||
} else {
|
|
||||||
cx.answer("Please, send /get or /reset").await?;
|
|
||||||
next(state)
|
|
||||||
}
|
|
||||||
}
|
|
29
src/dispatching2/dialogue/dialogue_handler_ext.rs
Normal file
29
src/dispatching2/dialogue/dialogue_handler_ext.rs
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
use crate::dispatching2::dialogue::{get_chat_id::GetChatId, Dialogue, Storage};
|
||||||
|
use dptree::{di::DependencyMap, Handler, Insert};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub trait DialogueHandlerExt {
|
||||||
|
fn add_dialogue<Upd, S, D>(self) -> Self
|
||||||
|
where
|
||||||
|
S: Storage<D> + Send + Sync + 'static,
|
||||||
|
D: Send + Sync + 'static,
|
||||||
|
Upd: GetChatId + Send + Sync + 'static;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, Output> DialogueHandlerExt for Handler<'a, DependencyMap, Output>
|
||||||
|
where
|
||||||
|
Output: Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
fn add_dialogue<Upd, S, D>(self) -> Self
|
||||||
|
where
|
||||||
|
// FIXME: some of this requirements are useless.
|
||||||
|
S: Storage<D> + Send + Sync + 'static,
|
||||||
|
D: Send + Sync + 'static,
|
||||||
|
Upd: GetChatId + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
self.chain(dptree::map(|storage: Arc<S>, upd: Arc<Upd>| async move {
|
||||||
|
let chat_id = upd.chat_id()?;
|
||||||
|
Dialogue::new(storage, chat_id).ok()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
20
src/dispatching2/dialogue/get_chat_id.rs
Normal file
20
src/dispatching2/dialogue/get_chat_id.rs
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
use crate::types::CallbackQuery;
|
||||||
|
use teloxide_core::types::Message;
|
||||||
|
|
||||||
|
/// Something that maybe has a chat ID.
|
||||||
|
pub trait GetChatId {
|
||||||
|
#[must_use]
|
||||||
|
fn chat_id(&self) -> Option<i64>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GetChatId for Message {
|
||||||
|
fn chat_id(&self) -> Option<i64> {
|
||||||
|
Some(self.chat.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GetChatId for CallbackQuery {
|
||||||
|
fn chat_id(&self) -> Option<i64> {
|
||||||
|
self.message.as_ref().map(|mes| mes.chat.id)
|
||||||
|
}
|
||||||
|
}
|
85
src/dispatching2/dialogue/mod.rs
Normal file
85
src/dispatching2/dialogue/mod.rs
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
#[cfg(feature = "redis-storage")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||||
|
pub use storage::{RedisStorage, RedisStorageError};
|
||||||
|
|
||||||
|
#[cfg(feature = "sqlite-storage")]
|
||||||
|
pub use storage::{SqliteStorage, SqliteStorageError};
|
||||||
|
|
||||||
|
pub use storage::{serializer, InMemStorage, InMemStorageError, Serializer, Storage, TraceStorage};
|
||||||
|
|
||||||
|
pub use dialogue_handler_ext::DialogueHandlerExt;
|
||||||
|
|
||||||
|
use std::{future::Future, marker::PhantomData, sync::Arc};
|
||||||
|
|
||||||
|
mod dialogue_handler_ext;
|
||||||
|
mod get_chat_id;
|
||||||
|
mod storage;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Dialogue<D, S> {
|
||||||
|
// Maybe it's better to use Box<dyn Storage<D, Err>> here but it's require
|
||||||
|
// us to introduce `Err` generic parameter.
|
||||||
|
storage: Arc<S>,
|
||||||
|
chat_id: i64,
|
||||||
|
_phantom: PhantomData<D>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D, S> Dialogue<D, S>
|
||||||
|
where
|
||||||
|
D: Send + 'static,
|
||||||
|
S: Storage<D>,
|
||||||
|
{
|
||||||
|
pub fn new(storage: Arc<S>, chat_id: i64) -> Result<Self, S::Error> {
|
||||||
|
Ok(Self { storage, chat_id, _phantom: PhantomData })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Cache this.
|
||||||
|
pub async fn current_state(&self) -> Result<Option<D>, S::Error> {
|
||||||
|
self.storage.clone().get_dialogue(self.chat_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn current_state_or_default(&self) -> Result<D, S::Error>
|
||||||
|
where
|
||||||
|
D: Default,
|
||||||
|
{
|
||||||
|
match self.storage.clone().get_dialogue(self.chat_id).await? {
|
||||||
|
Some(d) => Ok(d),
|
||||||
|
None => {
|
||||||
|
self.storage.clone().update_dialogue(self.chat_id, D::default()).await?;
|
||||||
|
Ok(D::default())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn next<State>(&self, state: State) -> Result<(), S::Error>
|
||||||
|
where
|
||||||
|
D: From<State>,
|
||||||
|
{
|
||||||
|
let new_dialogue = state.into();
|
||||||
|
self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn with<F, Fut, State>(&self, f: F) -> Result<(), S::Error>
|
||||||
|
where
|
||||||
|
F: FnOnce(Option<D>) -> Fut,
|
||||||
|
Fut: Future<Output = State>,
|
||||||
|
D: From<State>,
|
||||||
|
{
|
||||||
|
let current_dialogue = self.current_state().await?;
|
||||||
|
let new_dialogue = f(current_dialogue).await.into();
|
||||||
|
self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn reset(&self) -> Result<(), S::Error>
|
||||||
|
where
|
||||||
|
D: Default,
|
||||||
|
{
|
||||||
|
self.next(D::default()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn exit(&self) -> Result<(), S::Error> {
|
||||||
|
self.storage.clone().remove_dialogue(self.chat_id).await
|
||||||
|
}
|
||||||
|
}
|
73
src/dispatching2/dialogue/storage/in_mem_storage.rs
Normal file
73
src/dispatching2/dialogue/storage/in_mem_storage.rs
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
use super::Storage;
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
/// An error returned from [`InMemStorage`].
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum InMemStorageError {
|
||||||
|
/// Returned from [`InMemStorage::remove_dialogue`].
|
||||||
|
#[error("row not found")]
|
||||||
|
DialogueNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A dialogue storage based on [`std::collections::HashMap`].
|
||||||
|
///
|
||||||
|
/// ## Note
|
||||||
|
/// All your dialogues will be lost after you restart your bot. If you need to
|
||||||
|
/// store them somewhere on a drive, you should use e.g.
|
||||||
|
/// [`super::SqliteStorage`] or implement your own.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct InMemStorage<D> {
|
||||||
|
map: Mutex<HashMap<i64, D>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> InMemStorage<S> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new() -> Arc<Self> {
|
||||||
|
Arc::new(Self { map: Mutex::new(HashMap::new()) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D> Storage<D> for InMemStorage<D>
|
||||||
|
where
|
||||||
|
D: Clone,
|
||||||
|
D: Send + 'static,
|
||||||
|
{
|
||||||
|
type Error = InMemStorageError;
|
||||||
|
|
||||||
|
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static,
|
||||||
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
self.map
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.remove(&chat_id)
|
||||||
|
.map_or(Err(InMemStorageError::DialogueNotFound), |_| Ok(()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
dialogue: D,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static,
|
||||||
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
self.map.lock().await.insert(chat_id, dialogue);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
|
Box::pin(async move { Ok(self.map.lock().await.get(&chat_id).map(ToOwned::to_owned)) })
|
||||||
|
}
|
||||||
|
}
|
76
src/dispatching2/dialogue/storage/mod.rs
Normal file
76
src/dispatching2/dialogue/storage/mod.rs
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
pub mod serializer;
|
||||||
|
|
||||||
|
mod in_mem_storage;
|
||||||
|
mod trace_storage;
|
||||||
|
|
||||||
|
#[cfg(feature = "redis-storage")]
|
||||||
|
mod redis_storage;
|
||||||
|
|
||||||
|
#[cfg(feature = "sqlite-storage")]
|
||||||
|
mod sqlite_storage;
|
||||||
|
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
|
||||||
|
pub use self::{
|
||||||
|
in_mem_storage::{InMemStorage, InMemStorageError},
|
||||||
|
trace_storage::TraceStorage,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "redis-storage")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||||
|
pub use redis_storage::{RedisStorage, RedisStorageError};
|
||||||
|
pub use serializer::Serializer;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "sqlite-storage")]
|
||||||
|
pub use sqlite_storage::{SqliteStorage, SqliteStorageError};
|
||||||
|
|
||||||
|
/// A storage of dialogues.
|
||||||
|
///
|
||||||
|
/// You can implement this trait for a structure that communicates with a DB and
|
||||||
|
/// be sure that after you restart your bot, all the dialogues won't be lost.
|
||||||
|
///
|
||||||
|
/// `Storage` is used only to store dialogue states, i.e. it can't be used as a
|
||||||
|
/// generic database.
|
||||||
|
///
|
||||||
|
/// Currently we support the following storages out of the box:
|
||||||
|
///
|
||||||
|
/// - [`InMemStorage`] -- a storage based on [`std::collections::HashMap`].
|
||||||
|
/// - [`RedisStorage`] -- a Redis-based storage.
|
||||||
|
/// - [`SqliteStorage`] -- an SQLite-based persistent storage.
|
||||||
|
///
|
||||||
|
/// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage
|
||||||
|
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
|
||||||
|
/// [`SqliteStorage`]: crate::dispatching::dialogue::SqliteStorage
|
||||||
|
pub trait Storage<D> {
|
||||||
|
type Error;
|
||||||
|
|
||||||
|
/// Removes a dialogue indexed by `chat_id`.
|
||||||
|
///
|
||||||
|
/// If the dialogue indexed by `chat_id` does not exist, this function
|
||||||
|
/// results in an error.
|
||||||
|
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||||
|
fn remove_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static;
|
||||||
|
|
||||||
|
/// Updates a dialogue indexed by `chat_id` with `dialogue`.
|
||||||
|
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||||
|
fn update_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
dialogue: D,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static;
|
||||||
|
|
||||||
|
/// Returns the dialogue indexed by `chat_id`.
|
||||||
|
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||||
|
fn get_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>>;
|
||||||
|
}
|
110
src/dispatching2/dialogue/storage/redis_storage.rs
Normal file
110
src/dispatching2/dialogue/storage/redis_storage.rs
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
use super::{serializer::Serializer, Storage};
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use redis::{AsyncCommands, IntoConnectionInfo};
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
fmt::{Debug, Display},
|
||||||
|
ops::DerefMut,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
/// An error returned from [`RedisStorage`].
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum RedisStorageError<SE>
|
||||||
|
where
|
||||||
|
SE: Debug + Display,
|
||||||
|
{
|
||||||
|
#[error("parsing/serializing error: {0}")]
|
||||||
|
SerdeError(SE),
|
||||||
|
|
||||||
|
#[error("error from Redis: {0}")]
|
||||||
|
RedisError(#[from] redis::RedisError),
|
||||||
|
|
||||||
|
/// Returned from [`RedisStorage::remove_dialogue`].
|
||||||
|
#[error("row not found")]
|
||||||
|
DialogueNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A dialogue storage based on [Redis](https://redis.io/).
|
||||||
|
pub struct RedisStorage<S> {
|
||||||
|
conn: Mutex<redis::aio::Connection>,
|
||||||
|
serializer: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> RedisStorage<S> {
|
||||||
|
pub async fn open(
|
||||||
|
url: impl IntoConnectionInfo,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<Arc<Self>, RedisStorageError<Infallible>> {
|
||||||
|
Ok(Arc::new(Self {
|
||||||
|
conn: Mutex::new(redis::Client::open(url)?.get_async_connection().await?),
|
||||||
|
serializer,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, D> Storage<D> for RedisStorage<S>
|
||||||
|
where
|
||||||
|
S: Send + Sync + Serializer<D> + 'static,
|
||||||
|
D: Send + Serialize + DeserializeOwned + 'static,
|
||||||
|
<S as Serializer<D>>::Error: Debug + Display,
|
||||||
|
{
|
||||||
|
type Error = RedisStorageError<<S as Serializer<D>>::Error>;
|
||||||
|
|
||||||
|
fn remove_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
let deleted_rows_count = redis::pipe()
|
||||||
|
.atomic()
|
||||||
|
.del(chat_id)
|
||||||
|
.query_async::<_, redis::Value>(self.conn.lock().await.deref_mut())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let redis::Value::Bulk(values) = deleted_rows_count {
|
||||||
|
// False positive
|
||||||
|
#[allow(clippy::collapsible_match)]
|
||||||
|
if let redis::Value::Int(deleted_rows_count) = values[0] {
|
||||||
|
match deleted_rows_count {
|
||||||
|
0 => return Err(RedisStorageError::DialogueNotFound),
|
||||||
|
_ => return Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unreachable!("Must return redis::Value::Bulk(redis::Value::Int(_))");
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
dialogue: D,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
let dialogue =
|
||||||
|
self.serializer.serialize(&dialogue).map_err(RedisStorageError::SerdeError)?;
|
||||||
|
self.conn.lock().await.set::<_, Vec<u8>, _>(chat_id, dialogue).await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
self.conn
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.get::<_, Option<Vec<u8>>>(chat_id)
|
||||||
|
.await?
|
||||||
|
.map(|d| self.serializer.deserialize(&d).map_err(RedisStorageError::SerdeError))
|
||||||
|
.transpose()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
77
src/dispatching2/dialogue/storage/serializer.rs
Normal file
77
src/dispatching2/dialogue/storage/serializer.rs
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
//! Various serializers for dialogue storages.
|
||||||
|
|
||||||
|
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||||
|
|
||||||
|
/// A serializer for memory storages.
|
||||||
|
pub trait Serializer<D> {
|
||||||
|
type Error;
|
||||||
|
|
||||||
|
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error>;
|
||||||
|
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The JSON serializer for memory storages.
|
||||||
|
pub struct Json;
|
||||||
|
|
||||||
|
impl<D> Serializer<D> for Json
|
||||||
|
where
|
||||||
|
D: Serialize + DeserializeOwned,
|
||||||
|
{
|
||||||
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
|
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||||
|
serde_json::to_vec(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||||
|
serde_json::from_slice(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The [CBOR] serializer for memory storages.
|
||||||
|
///
|
||||||
|
/// [CBOR]: https://en.wikipedia.org/wiki/CBOR
|
||||||
|
#[cfg(feature = "cbor-serializer")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))]
|
||||||
|
pub struct Cbor;
|
||||||
|
|
||||||
|
#[cfg(feature = "cbor-serializer")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))]
|
||||||
|
impl<D> Serializer<D> for Cbor
|
||||||
|
where
|
||||||
|
D: Serialize + DeserializeOwned,
|
||||||
|
{
|
||||||
|
type Error = serde_cbor::Error;
|
||||||
|
|
||||||
|
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||||
|
serde_cbor::to_vec(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||||
|
serde_cbor::from_slice(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The [Bincode] serializer for memory storages.
|
||||||
|
///
|
||||||
|
/// [Bincode]: https://github.com/servo/bincode
|
||||||
|
#[cfg(feature = "bincode-serializer")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))]
|
||||||
|
pub struct Bincode;
|
||||||
|
|
||||||
|
#[cfg(feature = "bincode-serializer")]
|
||||||
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))]
|
||||||
|
impl<D> Serializer<D> for Bincode
|
||||||
|
where
|
||||||
|
D: Serialize + DeserializeOwned,
|
||||||
|
{
|
||||||
|
type Error = bincode::Error;
|
||||||
|
|
||||||
|
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||||
|
bincode::serialize(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||||
|
bincode::deserialize(data)
|
||||||
|
}
|
||||||
|
}
|
141
src/dispatching2/dialogue/storage/sqlite_storage.rs
Normal file
141
src/dispatching2/dialogue/storage/sqlite_storage.rs
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
use super::{serializer::Serializer, Storage};
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
|
use sqlx::{sqlite::SqlitePool, Executor};
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
fmt::{Debug, Display},
|
||||||
|
str,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// A persistent dialogue storage based on [SQLite](https://www.sqlite.org/).
|
||||||
|
pub struct SqliteStorage<S> {
|
||||||
|
pool: SqlitePool,
|
||||||
|
serializer: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error returned from [`SqliteStorage`].
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum SqliteStorageError<SE>
|
||||||
|
where
|
||||||
|
SE: Debug + Display,
|
||||||
|
{
|
||||||
|
#[error("dialogue serialization error: {0}")]
|
||||||
|
SerdeError(SE),
|
||||||
|
|
||||||
|
#[error("sqlite error: {0}")]
|
||||||
|
SqliteError(#[from] sqlx::Error),
|
||||||
|
|
||||||
|
/// Returned from [`SqliteStorage::remove_dialogue`].
|
||||||
|
#[error("row not found")]
|
||||||
|
DialogueNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> SqliteStorage<S> {
|
||||||
|
pub async fn open(
|
||||||
|
path: &str,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
|
||||||
|
let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?;
|
||||||
|
let mut conn = pool.acquire().await?;
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
CREATE TABLE IF NOT EXISTS teloxide_dialogues (
|
||||||
|
chat_id BIGINT PRIMARY KEY,
|
||||||
|
dialogue BLOB NOT NULL
|
||||||
|
);
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.execute(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Arc::new(Self { pool, serializer }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, D> Storage<D> for SqliteStorage<S>
|
||||||
|
where
|
||||||
|
S: Send + Sync + Serializer<D> + 'static,
|
||||||
|
D: Send + Serialize + DeserializeOwned + 'static,
|
||||||
|
<S as Serializer<D>>::Error: Debug + Display,
|
||||||
|
{
|
||||||
|
type Error = SqliteStorageError<<S as Serializer<D>>::Error>;
|
||||||
|
|
||||||
|
/// Returns [`sqlx::Error::RowNotFound`] if a dialogue does not exist.
|
||||||
|
fn remove_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
let deleted_rows_count =
|
||||||
|
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
|
||||||
|
.bind(chat_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?
|
||||||
|
.rows_affected();
|
||||||
|
|
||||||
|
if deleted_rows_count == 0 {
|
||||||
|
return Err(SqliteStorageError::DialogueNotFound);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
dialogue: D,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
let d = self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
|
||||||
|
|
||||||
|
self.pool
|
||||||
|
.acquire()
|
||||||
|
.await?
|
||||||
|
.execute(
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO teloxide_dialogues VALUES (?, ?)
|
||||||
|
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(chat_id)
|
||||||
|
.bind(d),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
|
Box::pin(async move {
|
||||||
|
get_dialogue(&self.pool, chat_id)
|
||||||
|
.await?
|
||||||
|
.map(|d| self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError))
|
||||||
|
.transpose()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_dialogue(pool: &SqlitePool, chat_id: i64) -> Result<Option<Vec<u8>>, sqlx::Error> {
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
struct DialogueDbRow {
|
||||||
|
dialogue: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let bytes = sqlx::query_as::<_, DialogueDbRow>(
|
||||||
|
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
|
||||||
|
)
|
||||||
|
.bind(chat_id)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await?
|
||||||
|
.map(|r| r.dialogue);
|
||||||
|
|
||||||
|
Ok(bytes)
|
||||||
|
}
|
68
src/dispatching2/dialogue/storage/trace_storage.rs
Normal file
68
src/dispatching2/dialogue/storage/trace_storage.rs
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
use std::{
|
||||||
|
fmt::Debug,
|
||||||
|
marker::{Send, Sync},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
|
||||||
|
use crate::dispatching::dialogue::Storage;
|
||||||
|
|
||||||
|
/// A dialogue storage wrapper which logs all actions performed on an underlying
|
||||||
|
/// storage.
|
||||||
|
///
|
||||||
|
/// Reports about any dialogue action via [`log::Level::Trace`].
|
||||||
|
pub struct TraceStorage<S> {
|
||||||
|
inner: Arc<S>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> TraceStorage<S> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(inner: Arc<S>) -> Arc<Self> {
|
||||||
|
Arc::new(Self { inner })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_inner(self) -> Arc<S> {
|
||||||
|
self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, D> Storage<D> for TraceStorage<S>
|
||||||
|
where
|
||||||
|
D: Debug,
|
||||||
|
S: Storage<D> + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Error = <S as Storage<D>>::Error;
|
||||||
|
|
||||||
|
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static,
|
||||||
|
{
|
||||||
|
log::trace!("Removing dialogue #{}", chat_id);
|
||||||
|
<S as Storage<D>>::remove_dialogue(self.inner.clone(), chat_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
dialogue: D,
|
||||||
|
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||||
|
where
|
||||||
|
D: Send + 'static,
|
||||||
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
let to = format!("{:#?}", dialogue);
|
||||||
|
<S as Storage<D>>::update_dialogue(self.inner.clone(), chat_id, dialogue).await?;
|
||||||
|
log::trace!("Updated a dialogue #{}: {:#?}", chat_id, to);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_dialogue(
|
||||||
|
self: Arc<Self>,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
|
log::trace!("Requested a dialogue #{}", chat_id);
|
||||||
|
<S as Storage<D>>::get_dialogue(self.inner.clone(), chat_id)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
pub(crate) mod repls;
|
pub(crate) mod repls;
|
||||||
|
|
||||||
|
pub mod dialogue;
|
||||||
mod dispatcher;
|
mod dispatcher;
|
||||||
|
|
||||||
pub use dispatcher::Dispatcher;
|
pub use dispatcher::Dispatcher;
|
||||||
|
|
|
@ -15,7 +15,10 @@ pub use crate::dispatching::{
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(not(feature = "old_dispatching"))]
|
#[cfg(not(feature = "old_dispatching"))]
|
||||||
pub use crate::dispatching2::Dispatcher;
|
pub use crate::dispatching2::{
|
||||||
|
dialogue::{Dialogue, DialogueHandlerExt as _},
|
||||||
|
Dispatcher,
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "macros")))]
|
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "macros")))]
|
||||||
#[cfg(feature = "macros")]
|
#[cfg(feature = "macros")]
|
||||||
|
|
Loading…
Add table
Reference in a new issue