mirror of
https://github.com/teloxide/teloxide.git
synced 2025-01-08 19:33:53 +01:00
added dialogues + updated sqlite_remember_bot example.
This commit is contained in:
parent
3f1d1360c6
commit
6959d1c928
15 changed files with 743 additions and 90 deletions
|
@ -7,6 +7,7 @@ edition = "2018"
|
|||
[dependencies]
|
||||
# You can also choose "cbor-serializer" or built-in JSON serializer
|
||||
teloxide = { path = "../../", features = ["sqlite-storage", "bincode-serializer", "redis-storage", "macros", "auto-send"] }
|
||||
dptree = { path = "../../../chakka" }
|
||||
|
||||
log = "0.4.8"
|
||||
pretty_env_logger = "0.4.0"
|
||||
|
@ -16,4 +17,3 @@ serde = "1.0.104"
|
|||
futures = "0.3.5"
|
||||
|
||||
thiserror = "1.0.15"
|
||||
derive_more = "0.99.9"
|
||||
|
|
|
@ -1,19 +1,15 @@
|
|||
#[macro_use]
|
||||
extern crate derive_more;
|
||||
|
||||
mod states;
|
||||
mod transitions;
|
||||
|
||||
use states::*;
|
||||
|
||||
use teloxide::{
|
||||
dispatching::dialogue::{serializer::Json, SqliteStorage, Storage},
|
||||
dispatching2::dialogue::{serializer::Json, SqliteStorage, Storage},
|
||||
prelude::*,
|
||||
RequestError,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use std::sync::Arc;
|
||||
|
||||
type StorageError = <SqliteStorage<Json> as Storage<Dialogue>>::Error;
|
||||
type Store = SqliteStorage<Json>;
|
||||
// FIXME: naming
|
||||
type MyDialogue = Dialogue<BotDialogue, Store>;
|
||||
type StorageError = <SqliteStorage<Json> as Storage<BotDialogue>>::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum Error {
|
||||
|
@ -23,33 +19,69 @@ enum Error {
|
|||
StorageError(#[from] StorageError),
|
||||
}
|
||||
|
||||
type In = DialogueWithCx<AutoSend<Bot>, Message, Dialogue, StorageError>;
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
pub enum BotDialogue {
|
||||
Start,
|
||||
HaveNumber(i32),
|
||||
}
|
||||
|
||||
impl Default for BotDialogue {
|
||||
fn default() -> Self {
|
||||
Self::Start
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_message(
|
||||
cx: UpdateWithCx<AutoSend<Bot>, Message>,
|
||||
dialogue: Dialogue,
|
||||
) -> TransitionOut<Dialogue> {
|
||||
match cx.update.text().map(ToOwned::to_owned) {
|
||||
bot: Arc<AutoSend<Bot>>,
|
||||
mes: Arc<Message>,
|
||||
dialogue: Arc<MyDialogue>,
|
||||
) -> Result<(), Error> {
|
||||
match mes.text() {
|
||||
None => {
|
||||
cx.answer("Send me a text message.").await?;
|
||||
next(dialogue)
|
||||
bot.send_message(mes.chat.id, "Send me a text message.").await?;
|
||||
}
|
||||
Some(ans) => dialogue.react(cx, ans).await,
|
||||
Some(ans) => {
|
||||
let state = dialogue.current_state_or_default().await?;
|
||||
match state {
|
||||
BotDialogue::Start => {
|
||||
if let Ok(number) = ans.parse() {
|
||||
dialogue.next(BotDialogue::HaveNumber(number)).await?;
|
||||
bot.send_message(mes.chat.id, format!("Remembered number {}. Now use /get or /reset", number)).await?;
|
||||
} else {
|
||||
bot.send_message(mes.chat.id, "Please, send me a number").await?;
|
||||
}
|
||||
}
|
||||
BotDialogue::HaveNumber(num) => {
|
||||
if ans.starts_with("/get") {
|
||||
bot.send_message(mes.chat.id, format!("Here is your number: {}", num)).await?;
|
||||
} else if ans.starts_with("/reset") {
|
||||
dialogue.reset().await?;
|
||||
bot.send_message(mes.chat.id, "Resetted number").await?;
|
||||
} else {
|
||||
bot.send_message(mes.chat.id, "Please, send /get or /reset").await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let bot = Bot::from_env().auto_send();
|
||||
let bot = Arc::new(Bot::from_env().auto_send());
|
||||
let storage = SqliteStorage::open("db.sqlite", Json).await.unwrap();
|
||||
|
||||
Dispatcher::new(bot)
|
||||
.messages_handler(DialogueDispatcher::with_storage(
|
||||
|DialogueWithCx { cx, dialogue }: In| async move {
|
||||
let dialogue = dialogue.expect("std::convert::Infallible");
|
||||
handle_message(cx, dialogue).await.expect("Something wrong with the bot!")
|
||||
},
|
||||
SqliteStorage::open("db.sqlite", Json).await.unwrap(),
|
||||
))
|
||||
.dependencies({
|
||||
let mut map = dptree::di::DependencyMap::new();
|
||||
map.insert_arc(storage);
|
||||
map
|
||||
})
|
||||
.messages_handler(|h| {
|
||||
h.add_dialogue::<Message, Store, BotDialogue>()
|
||||
.branch(dptree::endpoint(handle_message))
|
||||
})
|
||||
.dispatch()
|
||||
.await;
|
||||
}
|
||||
|
|
|
@ -1,23 +0,0 @@
|
|||
use teloxide::macros::Transition;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Transition, From, Serialize, Deserialize)]
|
||||
pub enum Dialogue {
|
||||
Start(StartState),
|
||||
HaveNumber(HaveNumberState),
|
||||
}
|
||||
|
||||
impl Default for Dialogue {
|
||||
fn default() -> Self {
|
||||
Self::Start(StartState)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct StartState;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct HaveNumberState {
|
||||
pub number: i32,
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
use teloxide::prelude::*;
|
||||
use teloxide::macros::teloxide;
|
||||
|
||||
use super::states::*;
|
||||
|
||||
#[teloxide(subtransition)]
|
||||
async fn start(
|
||||
state: StartState,
|
||||
cx: TransitionIn<AutoSend<Bot>>,
|
||||
ans: String,
|
||||
) -> TransitionOut<Dialogue> {
|
||||
if let Ok(number) = ans.parse() {
|
||||
cx.answer(format!("Remembered number {}. Now use /get or /reset", number)).await?;
|
||||
next(HaveNumberState { number })
|
||||
} else {
|
||||
cx.answer("Please, send me a number").await?;
|
||||
next(state)
|
||||
}
|
||||
}
|
||||
|
||||
#[teloxide(subtransition)]
|
||||
async fn have_number(
|
||||
state: HaveNumberState,
|
||||
cx: TransitionIn<AutoSend<Bot>>,
|
||||
ans: String,
|
||||
) -> TransitionOut<Dialogue> {
|
||||
let num = state.number;
|
||||
|
||||
if ans.starts_with("/get") {
|
||||
cx.answer(format!("Here is your number: {}", num)).await?;
|
||||
next(state)
|
||||
} else if ans.starts_with("/reset") {
|
||||
cx.answer("Resetted number").await?;
|
||||
next(StartState)
|
||||
} else {
|
||||
cx.answer("Please, send /get or /reset").await?;
|
||||
next(state)
|
||||
}
|
||||
}
|
29
src/dispatching2/dialogue/dialogue_handler_ext.rs
Normal file
29
src/dispatching2/dialogue/dialogue_handler_ext.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use crate::dispatching2::dialogue::{get_chat_id::GetChatId, Dialogue, Storage};
|
||||
use dptree::{di::DependencyMap, Handler, Insert};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait DialogueHandlerExt {
|
||||
fn add_dialogue<Upd, S, D>(self) -> Self
|
||||
where
|
||||
S: Storage<D> + Send + Sync + 'static,
|
||||
D: Send + Sync + 'static,
|
||||
Upd: GetChatId + Send + Sync + 'static;
|
||||
}
|
||||
|
||||
impl<'a, Output> DialogueHandlerExt for Handler<'a, DependencyMap, Output>
|
||||
where
|
||||
Output: Send + Sync + 'static,
|
||||
{
|
||||
fn add_dialogue<Upd, S, D>(self) -> Self
|
||||
where
|
||||
// FIXME: some of this requirements are useless.
|
||||
S: Storage<D> + Send + Sync + 'static,
|
||||
D: Send + Sync + 'static,
|
||||
Upd: GetChatId + Send + Sync + 'static,
|
||||
{
|
||||
self.chain(dptree::map(|storage: Arc<S>, upd: Arc<Upd>| async move {
|
||||
let chat_id = upd.chat_id()?;
|
||||
Dialogue::new(storage, chat_id).ok()
|
||||
}))
|
||||
}
|
||||
}
|
20
src/dispatching2/dialogue/get_chat_id.rs
Normal file
20
src/dispatching2/dialogue/get_chat_id.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
use crate::types::CallbackQuery;
|
||||
use teloxide_core::types::Message;
|
||||
|
||||
/// Something that maybe has a chat ID.
|
||||
pub trait GetChatId {
|
||||
#[must_use]
|
||||
fn chat_id(&self) -> Option<i64>;
|
||||
}
|
||||
|
||||
impl GetChatId for Message {
|
||||
fn chat_id(&self) -> Option<i64> {
|
||||
Some(self.chat.id)
|
||||
}
|
||||
}
|
||||
|
||||
impl GetChatId for CallbackQuery {
|
||||
fn chat_id(&self) -> Option<i64> {
|
||||
self.message.as_ref().map(|mes| mes.chat.id)
|
||||
}
|
||||
}
|
85
src/dispatching2/dialogue/mod.rs
Normal file
85
src/dispatching2/dialogue/mod.rs
Normal file
|
@ -0,0 +1,85 @@
|
|||
#[cfg(feature = "redis-storage")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||
pub use storage::{RedisStorage, RedisStorageError};
|
||||
|
||||
#[cfg(feature = "sqlite-storage")]
|
||||
pub use storage::{SqliteStorage, SqliteStorageError};
|
||||
|
||||
pub use storage::{serializer, InMemStorage, InMemStorageError, Serializer, Storage, TraceStorage};
|
||||
|
||||
pub use dialogue_handler_ext::DialogueHandlerExt;
|
||||
|
||||
use std::{future::Future, marker::PhantomData, sync::Arc};
|
||||
|
||||
mod dialogue_handler_ext;
|
||||
mod get_chat_id;
|
||||
mod storage;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Dialogue<D, S> {
|
||||
// Maybe it's better to use Box<dyn Storage<D, Err>> here but it's require
|
||||
// us to introduce `Err` generic parameter.
|
||||
storage: Arc<S>,
|
||||
chat_id: i64,
|
||||
_phantom: PhantomData<D>,
|
||||
}
|
||||
|
||||
impl<D, S> Dialogue<D, S>
|
||||
where
|
||||
D: Send + 'static,
|
||||
S: Storage<D>,
|
||||
{
|
||||
pub fn new(storage: Arc<S>, chat_id: i64) -> Result<Self, S::Error> {
|
||||
Ok(Self { storage, chat_id, _phantom: PhantomData })
|
||||
}
|
||||
|
||||
// TODO: Cache this.
|
||||
pub async fn current_state(&self) -> Result<Option<D>, S::Error> {
|
||||
self.storage.clone().get_dialogue(self.chat_id).await
|
||||
}
|
||||
|
||||
pub async fn current_state_or_default(&self) -> Result<D, S::Error>
|
||||
where
|
||||
D: Default,
|
||||
{
|
||||
match self.storage.clone().get_dialogue(self.chat_id).await? {
|
||||
Some(d) => Ok(d),
|
||||
None => {
|
||||
self.storage.clone().update_dialogue(self.chat_id, D::default()).await?;
|
||||
Ok(D::default())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next<State>(&self, state: State) -> Result<(), S::Error>
|
||||
where
|
||||
D: From<State>,
|
||||
{
|
||||
let new_dialogue = state.into();
|
||||
self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn with<F, Fut, State>(&self, f: F) -> Result<(), S::Error>
|
||||
where
|
||||
F: FnOnce(Option<D>) -> Fut,
|
||||
Fut: Future<Output = State>,
|
||||
D: From<State>,
|
||||
{
|
||||
let current_dialogue = self.current_state().await?;
|
||||
let new_dialogue = f(current_dialogue).await.into();
|
||||
self.storage.clone().update_dialogue(self.chat_id, new_dialogue).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn reset(&self) -> Result<(), S::Error>
|
||||
where
|
||||
D: Default,
|
||||
{
|
||||
self.next(D::default()).await
|
||||
}
|
||||
|
||||
pub async fn exit(&self) -> Result<(), S::Error> {
|
||||
self.storage.clone().remove_dialogue(self.chat_id).await
|
||||
}
|
||||
}
|
73
src/dispatching2/dialogue/storage/in_mem_storage.rs
Normal file
73
src/dispatching2/dialogue/storage/in_mem_storage.rs
Normal file
|
@ -0,0 +1,73 @@
|
|||
use super::Storage;
|
||||
use futures::future::BoxFuture;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// An error returned from [`InMemStorage`].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InMemStorageError {
|
||||
/// Returned from [`InMemStorage::remove_dialogue`].
|
||||
#[error("row not found")]
|
||||
DialogueNotFound,
|
||||
}
|
||||
|
||||
/// A dialogue storage based on [`std::collections::HashMap`].
|
||||
///
|
||||
/// ## Note
|
||||
/// All your dialogues will be lost after you restart your bot. If you need to
|
||||
/// store them somewhere on a drive, you should use e.g.
|
||||
/// [`super::SqliteStorage`] or implement your own.
|
||||
#[derive(Debug)]
|
||||
pub struct InMemStorage<D> {
|
||||
map: Mutex<HashMap<i64, D>>,
|
||||
}
|
||||
|
||||
impl<S> InMemStorage<S> {
|
||||
#[must_use]
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self { map: Mutex::new(HashMap::new()) })
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> Storage<D> for InMemStorage<D>
|
||||
where
|
||||
D: Clone,
|
||||
D: Send + 'static,
|
||||
{
|
||||
type Error = InMemStorageError;
|
||||
|
||||
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static,
|
||||
{
|
||||
Box::pin(async move {
|
||||
self.map
|
||||
.lock()
|
||||
.await
|
||||
.remove(&chat_id)
|
||||
.map_or(Err(InMemStorageError::DialogueNotFound), |_| Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
fn update_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
dialogue: D,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static,
|
||||
{
|
||||
Box::pin(async move {
|
||||
self.map.lock().await.insert(chat_id, dialogue);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||
Box::pin(async move { Ok(self.map.lock().await.get(&chat_id).map(ToOwned::to_owned)) })
|
||||
}
|
||||
}
|
76
src/dispatching2/dialogue/storage/mod.rs
Normal file
76
src/dispatching2/dialogue/storage/mod.rs
Normal file
|
@ -0,0 +1,76 @@
|
|||
pub mod serializer;
|
||||
|
||||
mod in_mem_storage;
|
||||
mod trace_storage;
|
||||
|
||||
#[cfg(feature = "redis-storage")]
|
||||
mod redis_storage;
|
||||
|
||||
#[cfg(feature = "sqlite-storage")]
|
||||
mod sqlite_storage;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
pub use self::{
|
||||
in_mem_storage::{InMemStorage, InMemStorageError},
|
||||
trace_storage::TraceStorage,
|
||||
};
|
||||
|
||||
#[cfg(feature = "redis-storage")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||
pub use redis_storage::{RedisStorage, RedisStorageError};
|
||||
pub use serializer::Serializer;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "sqlite-storage")]
|
||||
pub use sqlite_storage::{SqliteStorage, SqliteStorageError};
|
||||
|
||||
/// A storage of dialogues.
|
||||
///
|
||||
/// You can implement this trait for a structure that communicates with a DB and
|
||||
/// be sure that after you restart your bot, all the dialogues won't be lost.
|
||||
///
|
||||
/// `Storage` is used only to store dialogue states, i.e. it can't be used as a
|
||||
/// generic database.
|
||||
///
|
||||
/// Currently we support the following storages out of the box:
|
||||
///
|
||||
/// - [`InMemStorage`] -- a storage based on [`std::collections::HashMap`].
|
||||
/// - [`RedisStorage`] -- a Redis-based storage.
|
||||
/// - [`SqliteStorage`] -- an SQLite-based persistent storage.
|
||||
///
|
||||
/// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage
|
||||
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
|
||||
/// [`SqliteStorage`]: crate::dispatching::dialogue::SqliteStorage
|
||||
pub trait Storage<D> {
|
||||
type Error;
|
||||
|
||||
/// Removes a dialogue indexed by `chat_id`.
|
||||
///
|
||||
/// If the dialogue indexed by `chat_id` does not exist, this function
|
||||
/// results in an error.
|
||||
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static;
|
||||
|
||||
/// Updates a dialogue indexed by `chat_id` with `dialogue`.
|
||||
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||
fn update_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
dialogue: D,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static;
|
||||
|
||||
/// Returns the dialogue indexed by `chat_id`.
|
||||
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>>;
|
||||
}
|
110
src/dispatching2/dialogue/storage/redis_storage.rs
Normal file
110
src/dispatching2/dialogue/storage/redis_storage.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
use super::{serializer::Serializer, Storage};
|
||||
use futures::future::BoxFuture;
|
||||
use redis::{AsyncCommands, IntoConnectionInfo};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt::{Debug, Display},
|
||||
ops::DerefMut,
|
||||
sync::Arc,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// An error returned from [`RedisStorage`].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RedisStorageError<SE>
|
||||
where
|
||||
SE: Debug + Display,
|
||||
{
|
||||
#[error("parsing/serializing error: {0}")]
|
||||
SerdeError(SE),
|
||||
|
||||
#[error("error from Redis: {0}")]
|
||||
RedisError(#[from] redis::RedisError),
|
||||
|
||||
/// Returned from [`RedisStorage::remove_dialogue`].
|
||||
#[error("row not found")]
|
||||
DialogueNotFound,
|
||||
}
|
||||
|
||||
/// A dialogue storage based on [Redis](https://redis.io/).
|
||||
pub struct RedisStorage<S> {
|
||||
conn: Mutex<redis::aio::Connection>,
|
||||
serializer: S,
|
||||
}
|
||||
|
||||
impl<S> RedisStorage<S> {
|
||||
pub async fn open(
|
||||
url: impl IntoConnectionInfo,
|
||||
serializer: S,
|
||||
) -> Result<Arc<Self>, RedisStorageError<Infallible>> {
|
||||
Ok(Arc::new(Self {
|
||||
conn: Mutex::new(redis::Client::open(url)?.get_async_connection().await?),
|
||||
serializer,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, D> Storage<D> for RedisStorage<S>
|
||||
where
|
||||
S: Send + Sync + Serializer<D> + 'static,
|
||||
D: Send + Serialize + DeserializeOwned + 'static,
|
||||
<S as Serializer<D>>::Error: Debug + Display,
|
||||
{
|
||||
type Error = RedisStorageError<<S as Serializer<D>>::Error>;
|
||||
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
let deleted_rows_count = redis::pipe()
|
||||
.atomic()
|
||||
.del(chat_id)
|
||||
.query_async::<_, redis::Value>(self.conn.lock().await.deref_mut())
|
||||
.await?;
|
||||
|
||||
if let redis::Value::Bulk(values) = deleted_rows_count {
|
||||
// False positive
|
||||
#[allow(clippy::collapsible_match)]
|
||||
if let redis::Value::Int(deleted_rows_count) = values[0] {
|
||||
match deleted_rows_count {
|
||||
0 => return Err(RedisStorageError::DialogueNotFound),
|
||||
_ => return Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("Must return redis::Value::Bulk(redis::Value::Int(_))");
|
||||
})
|
||||
}
|
||||
|
||||
fn update_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
dialogue: D,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
let dialogue =
|
||||
self.serializer.serialize(&dialogue).map_err(RedisStorageError::SerdeError)?;
|
||||
self.conn.lock().await.set::<_, Vec<u8>, _>(chat_id, dialogue).await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||
Box::pin(async move {
|
||||
self.conn
|
||||
.lock()
|
||||
.await
|
||||
.get::<_, Option<Vec<u8>>>(chat_id)
|
||||
.await?
|
||||
.map(|d| self.serializer.deserialize(&d).map_err(RedisStorageError::SerdeError))
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
}
|
77
src/dispatching2/dialogue/storage/serializer.rs
Normal file
77
src/dispatching2/dialogue/storage/serializer.rs
Normal file
|
@ -0,0 +1,77 @@
|
|||
//! Various serializers for dialogue storages.
|
||||
|
||||
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||
|
||||
/// A serializer for memory storages.
|
||||
pub trait Serializer<D> {
|
||||
type Error;
|
||||
|
||||
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error>;
|
||||
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error>;
|
||||
}
|
||||
|
||||
/// The JSON serializer for memory storages.
|
||||
pub struct Json;
|
||||
|
||||
impl<D> Serializer<D> for Json
|
||||
where
|
||||
D: Serialize + DeserializeOwned,
|
||||
{
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||
serde_json::to_vec(val)
|
||||
}
|
||||
|
||||
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||
serde_json::from_slice(data)
|
||||
}
|
||||
}
|
||||
|
||||
/// The [CBOR] serializer for memory storages.
|
||||
///
|
||||
/// [CBOR]: https://en.wikipedia.org/wiki/CBOR
|
||||
#[cfg(feature = "cbor-serializer")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))]
|
||||
pub struct Cbor;
|
||||
|
||||
#[cfg(feature = "cbor-serializer")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "cbor-serializer")))]
|
||||
impl<D> Serializer<D> for Cbor
|
||||
where
|
||||
D: Serialize + DeserializeOwned,
|
||||
{
|
||||
type Error = serde_cbor::Error;
|
||||
|
||||
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||
serde_cbor::to_vec(val)
|
||||
}
|
||||
|
||||
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||
serde_cbor::from_slice(data)
|
||||
}
|
||||
}
|
||||
|
||||
/// The [Bincode] serializer for memory storages.
|
||||
///
|
||||
/// [Bincode]: https://github.com/servo/bincode
|
||||
#[cfg(feature = "bincode-serializer")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))]
|
||||
pub struct Bincode;
|
||||
|
||||
#[cfg(feature = "bincode-serializer")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "bincode-serializer")))]
|
||||
impl<D> Serializer<D> for Bincode
|
||||
where
|
||||
D: Serialize + DeserializeOwned,
|
||||
{
|
||||
type Error = bincode::Error;
|
||||
|
||||
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
|
||||
bincode::serialize(val)
|
||||
}
|
||||
|
||||
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
|
||||
bincode::deserialize(data)
|
||||
}
|
||||
}
|
141
src/dispatching2/dialogue/storage/sqlite_storage.rs
Normal file
141
src/dispatching2/dialogue/storage/sqlite_storage.rs
Normal file
|
@ -0,0 +1,141 @@
|
|||
use super::{serializer::Serializer, Storage};
|
||||
use futures::future::BoxFuture;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sqlx::{sqlite::SqlitePool, Executor};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt::{Debug, Display},
|
||||
str,
|
||||
sync::Arc,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
/// A persistent dialogue storage based on [SQLite](https://www.sqlite.org/).
|
||||
pub struct SqliteStorage<S> {
|
||||
pool: SqlitePool,
|
||||
serializer: S,
|
||||
}
|
||||
|
||||
/// An error returned from [`SqliteStorage`].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SqliteStorageError<SE>
|
||||
where
|
||||
SE: Debug + Display,
|
||||
{
|
||||
#[error("dialogue serialization error: {0}")]
|
||||
SerdeError(SE),
|
||||
|
||||
#[error("sqlite error: {0}")]
|
||||
SqliteError(#[from] sqlx::Error),
|
||||
|
||||
/// Returned from [`SqliteStorage::remove_dialogue`].
|
||||
#[error("row not found")]
|
||||
DialogueNotFound,
|
||||
}
|
||||
|
||||
impl<S> SqliteStorage<S> {
|
||||
pub async fn open(
|
||||
path: &str,
|
||||
serializer: S,
|
||||
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
|
||||
let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?;
|
||||
let mut conn = pool.acquire().await?;
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS teloxide_dialogues (
|
||||
chat_id BIGINT PRIMARY KEY,
|
||||
dialogue BLOB NOT NULL
|
||||
);
|
||||
"#,
|
||||
)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(Arc::new(Self { pool, serializer }))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, D> Storage<D> for SqliteStorage<S>
|
||||
where
|
||||
S: Send + Sync + Serializer<D> + 'static,
|
||||
D: Send + Serialize + DeserializeOwned + 'static,
|
||||
<S as Serializer<D>>::Error: Debug + Display,
|
||||
{
|
||||
type Error = SqliteStorageError<<S as Serializer<D>>::Error>;
|
||||
|
||||
/// Returns [`sqlx::Error::RowNotFound`] if a dialogue does not exist.
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
let deleted_rows_count =
|
||||
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
|
||||
.bind(chat_id)
|
||||
.execute(&self.pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if deleted_rows_count == 0 {
|
||||
return Err(SqliteStorageError::DialogueNotFound);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn update_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
dialogue: D,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
let d = self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
|
||||
|
||||
self.pool
|
||||
.acquire()
|
||||
.await?
|
||||
.execute(
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO teloxide_dialogues VALUES (?, ?)
|
||||
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
|
||||
"#,
|
||||
)
|
||||
.bind(chat_id)
|
||||
.bind(d),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||
Box::pin(async move {
|
||||
get_dialogue(&self.pool, chat_id)
|
||||
.await?
|
||||
.map(|d| self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError))
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_dialogue(pool: &SqlitePool, chat_id: i64) -> Result<Option<Vec<u8>>, sqlx::Error> {
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DialogueDbRow {
|
||||
dialogue: Vec<u8>,
|
||||
}
|
||||
|
||||
let bytes = sqlx::query_as::<_, DialogueDbRow>(
|
||||
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
|
||||
)
|
||||
.bind(chat_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.map(|r| r.dialogue);
|
||||
|
||||
Ok(bytes)
|
||||
}
|
68
src/dispatching2/dialogue/storage/trace_storage.rs
Normal file
68
src/dispatching2/dialogue/storage/trace_storage.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
use std::{
|
||||
fmt::Debug,
|
||||
marker::{Send, Sync},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
use crate::dispatching::dialogue::Storage;
|
||||
|
||||
/// A dialogue storage wrapper which logs all actions performed on an underlying
|
||||
/// storage.
|
||||
///
|
||||
/// Reports about any dialogue action via [`log::Level::Trace`].
|
||||
pub struct TraceStorage<S> {
|
||||
inner: Arc<S>,
|
||||
}
|
||||
|
||||
impl<S> TraceStorage<S> {
|
||||
#[must_use]
|
||||
pub fn new(inner: Arc<S>) -> Arc<Self> {
|
||||
Arc::new(Self { inner })
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Arc<S> {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, D> Storage<D> for TraceStorage<S>
|
||||
where
|
||||
D: Debug,
|
||||
S: Storage<D> + Send + Sync + 'static,
|
||||
{
|
||||
type Error = <S as Storage<D>>::Error;
|
||||
|
||||
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static,
|
||||
{
|
||||
log::trace!("Removing dialogue #{}", chat_id);
|
||||
<S as Storage<D>>::remove_dialogue(self.inner.clone(), chat_id)
|
||||
}
|
||||
|
||||
fn update_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
dialogue: D,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static,
|
||||
{
|
||||
Box::pin(async move {
|
||||
let to = format!("{:#?}", dialogue);
|
||||
<S as Storage<D>>::update_dialogue(self.inner.clone(), chat_id, dialogue).await?;
|
||||
log::trace!("Updated a dialogue #{}: {:#?}", chat_id, to);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||
log::trace!("Requested a dialogue #{}", chat_id);
|
||||
<S as Storage<D>>::get_dialogue(self.inner.clone(), chat_id)
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
pub(crate) mod repls;
|
||||
|
||||
pub mod dialogue;
|
||||
mod dispatcher;
|
||||
|
||||
pub use dispatcher::Dispatcher;
|
||||
|
|
|
@ -15,7 +15,10 @@ pub use crate::dispatching::{
|
|||
};
|
||||
|
||||
#[cfg(not(feature = "old_dispatching"))]
|
||||
pub use crate::dispatching2::Dispatcher;
|
||||
pub use crate::dispatching2::{
|
||||
dialogue::{Dialogue, DialogueHandlerExt as _},
|
||||
Dispatcher,
|
||||
};
|
||||
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "macros")))]
|
||||
#[cfg(feature = "macros")]
|
||||
|
|
Loading…
Reference in a new issue