mirror of
https://github.com/teloxide/teloxide.git
synced 2025-03-14 11:44:04 +01:00
Fail Storage::remove_dialogue if a dialogue doesn't exist
This commit is contained in:
parent
ef8c8f4cb5
commit
08bf40e555
12 changed files with 114 additions and 72 deletions
|
@ -8,11 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
### Added
|
||||
|
||||
- `Storage::get_dialogue`
|
||||
- `Storage::get_dialogue` to obtain a dialogue indexed by a chat ID.
|
||||
- `RedisStorageError::RowNotFound` to be returned from `RedisStorage::remove_dialogue`.
|
||||
- `InMemStorageError` with a single variant `RowNotFound` to be returned from `InMemStorage::remove_dialogue`.
|
||||
|
||||
### Changed
|
||||
|
||||
- Do not return a dialogue from `Storage::{remove_dialogue, update_dialogue}`.
|
||||
- Return an error from `Storage::remove_dialogue` if a dialogue does not exist.
|
||||
- Require `D: Clone` in `dialogues_repl(_with_listener)` and `InMemStorage`.
|
||||
- Automatically delete a webhook if it was set up in `update_listeners::polling_default` (thereby making it `async`, [issue 319](https://github.com/teloxide/teloxide/issues/319)).
|
||||
|
||||
|
|
|
@ -4,11 +4,12 @@ use crate::dispatching::{
|
|||
},
|
||||
DispatcherHandler, UpdateWithCx,
|
||||
};
|
||||
use std::{convert::Infallible, fmt::Debug, marker::PhantomData};
|
||||
use std::{fmt::Debug, marker::PhantomData};
|
||||
|
||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::dispatching::dialogue::InMemStorageError;
|
||||
use lockfree::map::Map;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use teloxide_core::requests::Requester;
|
||||
|
@ -45,7 +46,7 @@ pub struct DialogueDispatcher<R, D, S, H, Upd> {
|
|||
|
||||
impl<R, D, H, Upd> DialogueDispatcher<R, D, InMemStorage<D>, H, Upd>
|
||||
where
|
||||
H: DialogueDispatcherHandler<R, Upd, D, Infallible> + Send + Sync + 'static,
|
||||
H: DialogueDispatcherHandler<R, Upd, D, InMemStorageError> + Send + Sync + 'static,
|
||||
Upd: GetChatId + Send + 'static,
|
||||
D: Default + Send + 'static,
|
||||
{
|
||||
|
|
|
@ -166,7 +166,7 @@ pub use teloxide_macros::Transition;
|
|||
|
||||
#[cfg(feature = "redis-storage")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||
pub use storage::{RedisStorage, RedisStorageError};
|
||||
pub use storage::{InMemStorageError, RedisStorage, RedisStorageError};
|
||||
|
||||
#[cfg(feature = "sqlite-storage")]
|
||||
pub use storage::{SqliteStorage, SqliteStorageError};
|
||||
|
|
|
@ -1,18 +1,23 @@
|
|||
use super::Storage;
|
||||
use futures::future::BoxFuture;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// A memory storage based on a hash map. Stores all the dialogues directly in
|
||||
/// RAM.
|
||||
/// An error returned from [`InMemStorage`].
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InMemStorageError {
|
||||
/// Returned from [`InMemStorage::remove_dialogue`].
|
||||
#[error("row not found")]
|
||||
RowNotFound,
|
||||
}
|
||||
|
||||
/// A dialogue storage based on [`std::collections::HashMap`].
|
||||
///
|
||||
/// ## Note
|
||||
/// All the dialogues will be lost after you restart your bot. If you need to
|
||||
/// store them somewhere on a drive, you should use [`SqliteStorage`],
|
||||
/// [`RedisStorage`] or implement your own.
|
||||
///
|
||||
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
|
||||
/// [`SqliteStorage`]: crate::dispatching::dialogue::SqliteStorage
|
||||
/// All your dialogues will be lost after you restart your bot. If you need to
|
||||
/// store them somewhere on a drive, you should use e.g.
|
||||
/// [`super::SqliteStorage`] or implement your own.
|
||||
#[derive(Debug)]
|
||||
pub struct InMemStorage<D> {
|
||||
map: Mutex<HashMap<i64, D>>,
|
||||
|
@ -30,15 +35,18 @@ where
|
|||
D: Clone,
|
||||
D: Send + 'static,
|
||||
{
|
||||
type Error = std::convert::Infallible;
|
||||
type Error = InMemStorageError;
|
||||
|
||||
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
|
||||
where
|
||||
D: Send + 'static,
|
||||
{
|
||||
Box::pin(async move {
|
||||
self.map.lock().await.remove(&chat_id);
|
||||
Ok(())
|
||||
self.map
|
||||
.lock()
|
||||
.await
|
||||
.remove(&chat_id)
|
||||
.map_or_else(|| Err(InMemStorageError::RowNotFound), |_| Ok(()))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,10 @@ mod sqlite_storage;
|
|||
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
pub use self::{in_mem_storage::InMemStorage, trace_storage::TraceStorage};
|
||||
pub use self::{
|
||||
in_mem_storage::{InMemStorage, InMemStorageError},
|
||||
trace_storage::TraceStorage,
|
||||
};
|
||||
|
||||
#[cfg(feature = "redis-storage")]
|
||||
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
|
||||
|
@ -32,9 +35,9 @@ pub use sqlite_storage::{SqliteStorage, SqliteStorageError};
|
|||
///
|
||||
/// Currently we support the following storages out of the box:
|
||||
///
|
||||
/// - [`InMemStorage`] - a storage based on a simple hash map
|
||||
/// - [`RedisStorage`] - a Redis-based storage
|
||||
/// - [`SqliteStorage`] - an SQLite-based persistent storage
|
||||
/// - [`InMemStorage`] -- A storage based on [`std::collections::HashMap`].
|
||||
/// - [`RedisStorage`] -- A Redis-based storage.
|
||||
/// - [`SqliteStorage`] -- An SQLite-based persistent storage.
|
||||
///
|
||||
/// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage
|
||||
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
|
||||
|
@ -43,6 +46,9 @@ pub trait Storage<D> {
|
|||
type Error;
|
||||
|
||||
/// Removes a dialogue indexed by `chat_id`.
|
||||
///
|
||||
/// If the dialogue indexed by `chat_id` does not exist, this function
|
||||
/// results in an error.
|
||||
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
|
@ -61,7 +67,7 @@ pub trait Storage<D> {
|
|||
where
|
||||
D: Send + 'static;
|
||||
|
||||
/// Provides a dialogue indexed by `chat_id`.
|
||||
/// Extracts a dialogue indexed by `chat_id`.
|
||||
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
|
||||
fn get_dialogue(
|
||||
self: Arc<Self>,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{serializer::Serializer, Storage};
|
||||
use futures::future::BoxFuture;
|
||||
use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo};
|
||||
use redis::{AsyncCommands, IntoConnectionInfo};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
|
@ -12,8 +12,6 @@ use thiserror::Error;
|
|||
use tokio::sync::Mutex;
|
||||
|
||||
/// An error returned from [`RedisStorage`].
|
||||
///
|
||||
/// [`RedisStorage`]: struct.RedisStorage.html
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RedisStorageError<SE>
|
||||
where
|
||||
|
@ -21,11 +19,16 @@ where
|
|||
{
|
||||
#[error("parsing/serializing error: {0}")]
|
||||
SerdeError(SE),
|
||||
|
||||
#[error("error from Redis: {0}")]
|
||||
RedisError(#[from] redis::RedisError),
|
||||
|
||||
/// Returned from [`RedisStorage::remove_dialogue`].
|
||||
#[error("row not found")]
|
||||
RowNotFound,
|
||||
}
|
||||
|
||||
/// A memory storage based on [Redis](https://redis.io/).
|
||||
/// A dialogue storage based on [Redis](https://redis.io/).
|
||||
pub struct RedisStorage<S> {
|
||||
conn: Mutex<redis::aio::Connection>,
|
||||
serializer: S,
|
||||
|
@ -51,36 +54,30 @@ where
|
|||
{
|
||||
type Error = RedisStorageError<<S as Serializer<D>>::Error>;
|
||||
|
||||
// `.del().ignore()` is much more readable than `.del()\n.ignore()`
|
||||
#[rustfmt::skip]
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
let res = redis::pipe()
|
||||
let deleted_rows_count = redis::pipe()
|
||||
.atomic()
|
||||
.get(chat_id)
|
||||
.del(chat_id).ignore()
|
||||
.query_async::<_, redis::Value>(
|
||||
self.conn.lock().await.deref_mut(),
|
||||
)
|
||||
.del(chat_id)
|
||||
.query_async::<_, redis::Value>(self.conn.lock().await.deref_mut())
|
||||
.await?;
|
||||
// We're expecting `.pipe()` to return us an exactly one result in
|
||||
// bulk, so all other branches should be unreachable
|
||||
match res {
|
||||
redis::Value::Bulk(bulk) if bulk.len() == 1 => {
|
||||
Option::<Vec<u8>>::from_redis_value(&bulk[0])?
|
||||
.map(|v| {
|
||||
self.serializer
|
||||
.deserialize(&v)
|
||||
.map_err(RedisStorageError::SerdeError)
|
||||
})
|
||||
.transpose()?;
|
||||
Ok(())
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
||||
let deleted_rows_count = match deleted_rows_count {
|
||||
redis::Value::Bulk(values) => match values[0] {
|
||||
redis::Value::Int(x) => x,
|
||||
_ => unreachable!("Must return redis::Value::Int"),
|
||||
},
|
||||
_ => unreachable!("Must return redis::Value::Bulk"),
|
||||
};
|
||||
|
||||
if deleted_rows_count == 0 {
|
||||
return Err(RedisStorageError::RowNotFound);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//! Various serializers for memory storages.
|
||||
//! Various serializers for dialogue storages.
|
||||
|
||||
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||
|
||||
|
|
|
@ -10,15 +10,13 @@ use std::{
|
|||
};
|
||||
use thiserror::Error;
|
||||
|
||||
/// A persistent storage based on [SQLite](https://www.sqlite.org/).
|
||||
/// A persistent dialogue storage based on [SQLite](https://www.sqlite.org/).
|
||||
pub struct SqliteStorage<S> {
|
||||
pool: SqlitePool,
|
||||
serializer: S,
|
||||
}
|
||||
|
||||
/// An error returned from [`SqliteStorage`].
|
||||
///
|
||||
/// [`SqliteStorage`]: struct.SqliteStorage.html
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SqliteStorageError<SE>
|
||||
where
|
||||
|
@ -26,6 +24,7 @@ where
|
|||
{
|
||||
#[error("dialogue serialization error: {0}")]
|
||||
SerdeError(SE),
|
||||
|
||||
#[error("sqlite error: {0}")]
|
||||
SqliteError(#[from] sqlx::Error),
|
||||
}
|
||||
|
@ -60,15 +59,23 @@ where
|
|||
{
|
||||
type Error = SqliteStorageError<<S as Serializer<D>>::Error>;
|
||||
|
||||
/// Returns [`sqlx::Error::RowNotFound`] if a dialogue does not exist.
|
||||
fn remove_dialogue(
|
||||
self: Arc<Self>,
|
||||
chat_id: i64,
|
||||
) -> BoxFuture<'static, Result<(), Self::Error>> {
|
||||
Box::pin(async move {
|
||||
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
|
||||
.bind(chat_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
let deleted_rows_count =
|
||||
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?; SELECT changes()")
|
||||
.bind(chat_id)
|
||||
.execute(&self.pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
if deleted_rows_count == 0 {
|
||||
return Err(SqliteStorageError::SqliteError(sqlx::Error::RowNotFound));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -112,20 +119,22 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DialogueDbRow {
|
||||
dialogue: Vec<u8>,
|
||||
}
|
||||
|
||||
async fn get_dialogue(
|
||||
pool: &SqlitePool,
|
||||
chat_id: i64,
|
||||
) -> Result<Option<Box<Vec<u8>>>, sqlx::Error> {
|
||||
Ok(sqlx::query_as::<_, DialogueDbRow>(
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct DialogueDbRow {
|
||||
dialogue: Vec<u8>,
|
||||
}
|
||||
|
||||
let bytes = sqlx::query_as::<_, DialogueDbRow>(
|
||||
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
|
||||
)
|
||||
.bind(chat_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.map(|r| Box::new(r.dialogue)))
|
||||
.map(|r| Box::new(r.dialogue));
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
|
|
@ -8,10 +8,10 @@ use futures::future::BoxFuture;
|
|||
|
||||
use crate::dispatching::dialogue::Storage;
|
||||
|
||||
/// Storage wrapper for logging purposes.
|
||||
/// A dialogue storage wrapper which logs all actions performed on an underlying
|
||||
/// storage.
|
||||
///
|
||||
/// Reports about any dialogue action using the `trace` level in the `log`
|
||||
/// crate.
|
||||
/// Reports about any dialogue action via [`log::Level::Trace`].
|
||||
pub struct TraceStorage<S> {
|
||||
inner: Arc<S>,
|
||||
}
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use crate::{
|
||||
dispatching::{
|
||||
dialogue::{DialogueDispatcher, DialogueStage, DialogueWithCx},
|
||||
dialogue::{DialogueDispatcher, DialogueStage, DialogueWithCx, InMemStorageError},
|
||||
update_listeners,
|
||||
update_listeners::UpdateListener,
|
||||
Dispatcher, UpdateWithCx,
|
||||
},
|
||||
error_handlers::LoggingErrorHandler,
|
||||
};
|
||||
use std::{convert::Infallible, fmt::Debug, future::Future, sync::Arc};
|
||||
use std::{fmt::Debug, future::Future, sync::Arc};
|
||||
use teloxide_core::{requests::Requester, types::Message};
|
||||
|
||||
/// A [REPL] for dialogues.
|
||||
|
@ -71,7 +71,12 @@ pub async fn dialogues_repl_with_listener<'a, R, H, D, Fut, L, ListenerE>(
|
|||
|
||||
Dispatcher::new(requester)
|
||||
.messages_handler(DialogueDispatcher::new(
|
||||
move |DialogueWithCx { cx, dialogue }: DialogueWithCx<R, Message, D, Infallible>| {
|
||||
move |DialogueWithCx { cx, dialogue }: DialogueWithCx<
|
||||
R,
|
||||
Message,
|
||||
D,
|
||||
InMemStorageError,
|
||||
>| {
|
||||
let handler = Arc::clone(&handler);
|
||||
|
||||
async move {
|
||||
|
|
|
@ -2,12 +2,12 @@ use std::{
|
|||
fmt::{Debug, Display},
|
||||
sync::Arc,
|
||||
};
|
||||
use teloxide::dispatching::dialogue::{RedisStorage, Serializer, Storage};
|
||||
use teloxide::dispatching::dialogue::{RedisStorage, RedisStorageError, Serializer, Storage};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redis_json() {
|
||||
let storage = RedisStorage::open(
|
||||
"redis://127.0.0.1:7777",
|
||||
"redis://127.0.0.1:9000",
|
||||
teloxide::dispatching::dialogue::serializer::Json,
|
||||
)
|
||||
.await
|
||||
|
@ -18,7 +18,7 @@ async fn test_redis_json() {
|
|||
#[tokio::test]
|
||||
async fn test_redis_bincode() {
|
||||
let storage = RedisStorage::open(
|
||||
"redis://127.0.0.1:7778",
|
||||
"redis://127.0.0.1:9001",
|
||||
teloxide::dispatching::dialogue::serializer::Bincode,
|
||||
)
|
||||
.await
|
||||
|
@ -29,7 +29,7 @@ async fn test_redis_bincode() {
|
|||
#[tokio::test]
|
||||
async fn test_redis_cbor() {
|
||||
let storage = RedisStorage::open(
|
||||
"redis://127.0.0.1:7779",
|
||||
"redis://127.0.0.1:9002",
|
||||
teloxide::dispatching::dialogue::serializer::Cbor,
|
||||
)
|
||||
.await
|
||||
|
@ -70,4 +70,10 @@ where
|
|||
Arc::clone(&storage).remove_dialogue(256).await.unwrap();
|
||||
|
||||
test_dialogues!(storage, None, None, None);
|
||||
|
||||
// Check that a try to remove a non-existing dialogue results in an error.
|
||||
assert!(matches!(
|
||||
Arc::clone(&storage).remove_dialogue(1).await.unwrap_err(),
|
||||
RedisStorageError::RowNotFound
|
||||
));
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{
|
|||
fmt::{Debug, Display},
|
||||
sync::Arc,
|
||||
};
|
||||
use teloxide::dispatching::dialogue::{Serializer, SqliteStorage, Storage};
|
||||
use teloxide::dispatching::dialogue::{Serializer, SqliteStorage, SqliteStorageError, Storage};
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_sqlite_json() {
|
||||
|
@ -66,4 +66,11 @@ where
|
|||
Arc::clone(&storage).remove_dialogue(256).await.unwrap();
|
||||
|
||||
test_dialogues!(storage, None, None, None);
|
||||
|
||||
// Check that a try to remove a non-existing dialogue results in an error.
|
||||
let err = Arc::clone(&storage).remove_dialogue(1).await.unwrap_err();
|
||||
match err {
|
||||
SqliteStorageError::SqliteError(err) => assert!(matches!(err, sqlx::Error::RowNotFound)),
|
||||
_ => panic!("Must be sqlx::Error::RowNotFound"),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue