Fail Storage::remove_dialogue if a dialogue doesn't exist

This commit is contained in:
Temirkhan Myrzamadi 2021-05-08 17:21:24 +06:00
parent ef8c8f4cb5
commit 08bf40e555
12 changed files with 114 additions and 72 deletions

View file

@ -8,11 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `Storage::get_dialogue`
- `Storage::get_dialogue` to obtain a dialogue indexed by a chat ID.
- `RedisStorageError::RowNotFound` to be returned from `RedisStorage::remove_dialogue`.
- `InMemStorageError` with a single variant `RowNotFound` to be returned from `InMemStorage::remove_dialogue`.
### Changed
- Do not return a dialogue from `Storage::{remove_dialogue, update_dialogue}`.
- Return an error from `Storage::remove_dialogue` if a dialogue does not exist.
- Require `D: Clone` in `dialogues_repl(_with_listener)` and `InMemStorage`.
- Automatically delete a webhook if it was set up in `update_listeners::polling_default` (thereby making it `async`, [issue 319](https://github.com/teloxide/teloxide/issues/319)).

View file

@ -4,11 +4,12 @@ use crate::dispatching::{
},
DispatcherHandler, UpdateWithCx,
};
use std::{convert::Infallible, fmt::Debug, marker::PhantomData};
use std::{fmt::Debug, marker::PhantomData};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use tokio::sync::mpsc;
use crate::dispatching::dialogue::InMemStorageError;
use lockfree::map::Map;
use std::sync::{Arc, Mutex};
use teloxide_core::requests::Requester;
@ -45,7 +46,7 @@ pub struct DialogueDispatcher<R, D, S, H, Upd> {
impl<R, D, H, Upd> DialogueDispatcher<R, D, InMemStorage<D>, H, Upd>
where
H: DialogueDispatcherHandler<R, Upd, D, Infallible> + Send + Sync + 'static,
H: DialogueDispatcherHandler<R, Upd, D, InMemStorageError> + Send + Sync + 'static,
Upd: GetChatId + Send + 'static,
D: Default + Send + 'static,
{

View file

@ -166,7 +166,7 @@ pub use teloxide_macros::Transition;
#[cfg(feature = "redis-storage")]
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
pub use storage::{RedisStorage, RedisStorageError};
pub use storage::{InMemStorageError, RedisStorage, RedisStorageError};
#[cfg(feature = "sqlite-storage")]
pub use storage::{SqliteStorage, SqliteStorageError};

View file

@ -1,18 +1,23 @@
use super::Storage;
use futures::future::BoxFuture;
use std::{collections::HashMap, sync::Arc};
use thiserror::Error;
use tokio::sync::Mutex;
/// A memory storage based on a hash map. Stores all the dialogues directly in
/// RAM.
/// An error returned from [`InMemStorage`].
#[derive(Debug, Error)]
pub enum InMemStorageError {
/// Returned from [`InMemStorage::remove_dialogue`].
#[error("row not found")]
RowNotFound,
}
/// A dialogue storage based on [`std::collections::HashMap`].
///
/// ## Note
/// All the dialogues will be lost after you restart your bot. If you need to
/// store them somewhere on a drive, you should use [`SqliteStorage`],
/// [`RedisStorage`] or implement your own.
///
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
/// [`SqliteStorage`]: crate::dispatching::dialogue::SqliteStorage
/// All your dialogues will be lost after you restart your bot. If you need to
/// store them somewhere on a drive, you should use e.g.
/// [`super::SqliteStorage`] or implement your own.
#[derive(Debug)]
pub struct InMemStorage<D> {
map: Mutex<HashMap<i64, D>>,
@ -30,15 +35,18 @@ where
D: Clone,
D: Send + 'static,
{
type Error = std::convert::Infallible;
type Error = InMemStorageError;
fn remove_dialogue(self: Arc<Self>, chat_id: i64) -> BoxFuture<'static, Result<(), Self::Error>>
where
D: Send + 'static,
{
Box::pin(async move {
self.map.lock().await.remove(&chat_id);
Ok(())
self.map
.lock()
.await
.remove(&chat_id)
.map_or_else(|| Err(InMemStorageError::RowNotFound), |_| Ok(()))
})
}

View file

@ -11,7 +11,10 @@ mod sqlite_storage;
use futures::future::BoxFuture;
pub use self::{in_mem_storage::InMemStorage, trace_storage::TraceStorage};
pub use self::{
in_mem_storage::{InMemStorage, InMemStorageError},
trace_storage::TraceStorage,
};
#[cfg(feature = "redis-storage")]
#[cfg_attr(all(docsrs, feature = "nightly"), doc(cfg(feature = "redis-storage")))]
@ -32,9 +35,9 @@ pub use sqlite_storage::{SqliteStorage, SqliteStorageError};
///
/// Currently we support the following storages out of the box:
///
/// - [`InMemStorage`] - a storage based on a simple hash map
/// - [`RedisStorage`] - a Redis-based storage
/// - [`SqliteStorage`] - an SQLite-based persistent storage
/// - [`InMemStorage`] -- A storage based on [`std::collections::HashMap`].
/// - [`RedisStorage`] -- A Redis-based storage.
/// - [`SqliteStorage`] -- An SQLite-based persistent storage.
///
/// [`InMemStorage`]: crate::dispatching::dialogue::InMemStorage
/// [`RedisStorage`]: crate::dispatching::dialogue::RedisStorage
@ -43,6 +46,9 @@ pub trait Storage<D> {
type Error;
/// Removes a dialogue indexed by `chat_id`.
///
/// If the dialogue indexed by `chat_id` does not exist, this function
/// results in an error.
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
fn remove_dialogue(
self: Arc<Self>,
@ -61,7 +67,7 @@ pub trait Storage<D> {
where
D: Send + 'static;
/// Provides a dialogue indexed by `chat_id`.
/// Extracts a dialogue indexed by `chat_id`.
#[must_use = "Futures are lazy and do nothing unless polled with .await"]
fn get_dialogue(
self: Arc<Self>,

View file

@ -1,6 +1,6 @@
use super::{serializer::Serializer, Storage};
use futures::future::BoxFuture;
use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo};
use redis::{AsyncCommands, IntoConnectionInfo};
use serde::{de::DeserializeOwned, Serialize};
use std::{
convert::Infallible,
@ -12,8 +12,6 @@ use thiserror::Error;
use tokio::sync::Mutex;
/// An error returned from [`RedisStorage`].
///
/// [`RedisStorage`]: struct.RedisStorage.html
#[derive(Debug, Error)]
pub enum RedisStorageError<SE>
where
@ -21,11 +19,16 @@ where
{
#[error("parsing/serializing error: {0}")]
SerdeError(SE),
#[error("error from Redis: {0}")]
RedisError(#[from] redis::RedisError),
/// Returned from [`RedisStorage::remove_dialogue`].
#[error("row not found")]
RowNotFound,
}
/// A memory storage based on [Redis](https://redis.io/).
/// A dialogue storage based on [Redis](https://redis.io/).
pub struct RedisStorage<S> {
conn: Mutex<redis::aio::Connection>,
serializer: S,
@ -51,36 +54,30 @@ where
{
type Error = RedisStorageError<<S as Serializer<D>>::Error>;
// `.del().ignore()` is much more readable than `.del()\n.ignore()`
#[rustfmt::skip]
fn remove_dialogue(
self: Arc<Self>,
chat_id: i64,
) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::pin(async move {
let res = redis::pipe()
let deleted_rows_count = redis::pipe()
.atomic()
.get(chat_id)
.del(chat_id).ignore()
.query_async::<_, redis::Value>(
self.conn.lock().await.deref_mut(),
)
.del(chat_id)
.query_async::<_, redis::Value>(self.conn.lock().await.deref_mut())
.await?;
// We're expecting `.pipe()` to return us an exactly one result in
// bulk, so all other branches should be unreachable
match res {
redis::Value::Bulk(bulk) if bulk.len() == 1 => {
Option::<Vec<u8>>::from_redis_value(&bulk[0])?
.map(|v| {
self.serializer
.deserialize(&v)
.map_err(RedisStorageError::SerdeError)
})
.transpose()?;
Ok(())
}
_ => unreachable!(),
let deleted_rows_count = match deleted_rows_count {
redis::Value::Bulk(values) => match values[0] {
redis::Value::Int(x) => x,
_ => unreachable!("Must return redis::Value::Int"),
},
_ => unreachable!("Must return redis::Value::Bulk"),
};
if deleted_rows_count == 0 {
return Err(RedisStorageError::RowNotFound);
}
Ok(())
})
}

View file

@ -1,4 +1,4 @@
//! Various serializers for memory storages.
//! Various serializers for dialogue storages.
use serde::{de::DeserializeOwned, ser::Serialize};

View file

@ -10,15 +10,13 @@ use std::{
};
use thiserror::Error;
/// A persistent storage based on [SQLite](https://www.sqlite.org/).
/// A persistent dialogue storage based on [SQLite](https://www.sqlite.org/).
pub struct SqliteStorage<S> {
pool: SqlitePool,
serializer: S,
}
/// An error returned from [`SqliteStorage`].
///
/// [`SqliteStorage`]: struct.SqliteStorage.html
#[derive(Debug, Error)]
pub enum SqliteStorageError<SE>
where
@ -26,6 +24,7 @@ where
{
#[error("dialogue serialization error: {0}")]
SerdeError(SE),
#[error("sqlite error: {0}")]
SqliteError(#[from] sqlx::Error),
}
@ -60,15 +59,23 @@ where
{
type Error = SqliteStorageError<<S as Serializer<D>>::Error>;
/// Returns [`sqlx::Error::RowNotFound`] if a dialogue does not exist.
fn remove_dialogue(
self: Arc<Self>,
chat_id: i64,
) -> BoxFuture<'static, Result<(), Self::Error>> {
Box::pin(async move {
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
.bind(chat_id)
.execute(&self.pool)
.await?;
let deleted_rows_count =
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?; SELECT changes()")
.bind(chat_id)
.execute(&self.pool)
.await?
.rows_affected();
if deleted_rows_count == 0 {
return Err(SqliteStorageError::SqliteError(sqlx::Error::RowNotFound));
}
Ok(())
})
}
@ -112,20 +119,22 @@ where
}
}
#[derive(sqlx::FromRow)]
struct DialogueDbRow {
dialogue: Vec<u8>,
}
async fn get_dialogue(
pool: &SqlitePool,
chat_id: i64,
) -> Result<Option<Box<Vec<u8>>>, sqlx::Error> {
Ok(sqlx::query_as::<_, DialogueDbRow>(
#[derive(sqlx::FromRow)]
struct DialogueDbRow {
dialogue: Vec<u8>,
}
let bytes = sqlx::query_as::<_, DialogueDbRow>(
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
)
.bind(chat_id)
.fetch_optional(pool)
.await?
.map(|r| Box::new(r.dialogue)))
.map(|r| Box::new(r.dialogue));
Ok(bytes)
}

View file

@ -8,10 +8,10 @@ use futures::future::BoxFuture;
use crate::dispatching::dialogue::Storage;
/// Storage wrapper for logging purposes.
/// A dialogue storage wrapper which logs all actions performed on an underlying
/// storage.
///
/// Reports about any dialogue action using the `trace` level in the `log`
/// crate.
/// Reports about any dialogue action via [`log::Level::Trace`].
pub struct TraceStorage<S> {
inner: Arc<S>,
}

View file

@ -1,13 +1,13 @@
use crate::{
dispatching::{
dialogue::{DialogueDispatcher, DialogueStage, DialogueWithCx},
dialogue::{DialogueDispatcher, DialogueStage, DialogueWithCx, InMemStorageError},
update_listeners,
update_listeners::UpdateListener,
Dispatcher, UpdateWithCx,
},
error_handlers::LoggingErrorHandler,
};
use std::{convert::Infallible, fmt::Debug, future::Future, sync::Arc};
use std::{fmt::Debug, future::Future, sync::Arc};
use teloxide_core::{requests::Requester, types::Message};
/// A [REPL] for dialogues.
@ -71,7 +71,12 @@ pub async fn dialogues_repl_with_listener<'a, R, H, D, Fut, L, ListenerE>(
Dispatcher::new(requester)
.messages_handler(DialogueDispatcher::new(
move |DialogueWithCx { cx, dialogue }: DialogueWithCx<R, Message, D, Infallible>| {
move |DialogueWithCx { cx, dialogue }: DialogueWithCx<
R,
Message,
D,
InMemStorageError,
>| {
let handler = Arc::clone(&handler);
async move {

View file

@ -2,12 +2,12 @@ use std::{
fmt::{Debug, Display},
sync::Arc,
};
use teloxide::dispatching::dialogue::{RedisStorage, Serializer, Storage};
use teloxide::dispatching::dialogue::{RedisStorage, RedisStorageError, Serializer, Storage};
#[tokio::test]
async fn test_redis_json() {
let storage = RedisStorage::open(
"redis://127.0.0.1:7777",
"redis://127.0.0.1:9000",
teloxide::dispatching::dialogue::serializer::Json,
)
.await
@ -18,7 +18,7 @@ async fn test_redis_json() {
#[tokio::test]
async fn test_redis_bincode() {
let storage = RedisStorage::open(
"redis://127.0.0.1:7778",
"redis://127.0.0.1:9001",
teloxide::dispatching::dialogue::serializer::Bincode,
)
.await
@ -29,7 +29,7 @@ async fn test_redis_bincode() {
#[tokio::test]
async fn test_redis_cbor() {
let storage = RedisStorage::open(
"redis://127.0.0.1:7779",
"redis://127.0.0.1:9002",
teloxide::dispatching::dialogue::serializer::Cbor,
)
.await
@ -70,4 +70,10 @@ where
Arc::clone(&storage).remove_dialogue(256).await.unwrap();
test_dialogues!(storage, None, None, None);
// Check that a try to remove a non-existing dialogue results in an error.
assert!(matches!(
Arc::clone(&storage).remove_dialogue(1).await.unwrap_err(),
RedisStorageError::RowNotFound
));
}

View file

@ -2,7 +2,7 @@ use std::{
fmt::{Debug, Display},
sync::Arc,
};
use teloxide::dispatching::dialogue::{Serializer, SqliteStorage, Storage};
use teloxide::dispatching::dialogue::{Serializer, SqliteStorage, SqliteStorageError, Storage};
#[tokio::test(flavor = "multi_thread")]
async fn test_sqlite_json() {
@ -66,4 +66,11 @@ where
Arc::clone(&storage).remove_dialogue(256).await.unwrap();
test_dialogues!(storage, None, None, None);
// Check that a try to remove a non-existing dialogue results in an error.
let err = Arc::clone(&storage).remove_dialogue(1).await.unwrap_err();
match err {
SqliteStorageError::SqliteError(err) => assert!(matches!(err, sqlx::Error::RowNotFound)),
_ => panic!("Must be sqlx::Error::RowNotFound"),
}
}