diff --git a/src/dispatching/dialogue/storage/redis_storage.rs b/src/dispatching/dialogue/storage/redis_storage.rs index 31a358e8..a89576cf 100644 --- a/src/dispatching/dialogue/storage/redis_storage.rs +++ b/src/dispatching/dialogue/storage/redis_storage.rs @@ -56,7 +56,7 @@ where fn remove_dialogue( self: Arc<Self>, chat_id: i64, - ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + ) -> BoxFuture<'static, Result<(), Self::Error>> { Box::pin(async move { let res = redis::pipe() .atomic() @@ -70,13 +70,14 @@ where // bulk, so all other branches should be unreachable match res { redis::Value::Bulk(bulk) if bulk.len() == 1 => { - Ok(Option::<Vec<u8>>::from_redis_value(&bulk[0])? + Option::<Vec<u8>>::from_redis_value(&bulk[0])? .map(|v| { self.serializer .deserialize(&v) .map_err(RedisStorageError::SerdeError) }) - .transpose()?) + .transpose()?; + Ok(()) } _ => unreachable!(), } @@ -87,14 +88,24 @@ where self: Arc<Self>, chat_id: i64, dialogue: D, - ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + ) -> BoxFuture<'static, Result<(), Self::Error>> { Box::pin(async move { let dialogue = self.serializer.serialize(&dialogue).map_err(RedisStorageError::SerdeError)?; + self.conn.lock().await.getset::<_, Vec<u8>, Option<Vec<u8>>>(chat_id, dialogue).await?; + Ok(()) + }) + } + + fn get_dialogue( + self: Arc<Self>, + chat_id: i64, + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + Box::pin(async move { self.conn .lock() .await - .getset::<_, Vec<u8>, Option<Vec<u8>>>(chat_id, dialogue) + .get::<_, Option<Vec<u8>>>(chat_id) .await? .map(|d| self.serializer.deserialize(&d).map_err(RedisStorageError::SerdeError)) .transpose() diff --git a/src/dispatching/dialogue/storage/sqlite_storage.rs b/src/dispatching/dialogue/storage/sqlite_storage.rs index f4e4d98c..ceaeacf2 100644 --- a/src/dispatching/dialogue/storage/sqlite_storage.rs +++ b/src/dispatching/dialogue/storage/sqlite_storage.rs @@ -63,20 +63,16 @@ where fn remove_dialogue( self: Arc<Self>, chat_id: i64, - ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + ) -> BoxFuture<'static, Result<(), Self::Error>> { Box::pin(async move { - Ok(match get_dialogue(&self.pool, chat_id).await? { - Some(d) => { - let prev_dialogue = - self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?; - sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?") - .bind(chat_id) - .execute(&self.pool) - .await?; - Some(prev_dialogue) - } - _ => None, - }) + if get_dialogue(&self.pool, chat_id).await?.is_some() { + sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?") + .bind(chat_id) + .execute(&self.pool) + .await?; + } + + Ok(()) }) } @@ -84,14 +80,10 @@ where self: Arc<Self>, chat_id: i64, dialogue: D, - ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + ) -> BoxFuture<'static, Result<(), Self::Error>> { Box::pin(async move { - let prev_dialogue = get_dialogue(&self.pool, chat_id) - .await? - .map(|d| self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)) - .transpose()?; - let upd_dialogue = - self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; + let d = self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; + self.pool .acquire() .await? @@ -103,10 +95,22 @@ where "#, ) .bind(chat_id) - .bind(upd_dialogue), + .bind(d), ) .await?; - Ok(prev_dialogue) + Ok(()) + }) + } + + fn get_dialogue( + self: Arc<Self>, + chat_id: i64, + ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { + Box::pin(async move { + get_dialogue(&self.pool, chat_id) + .await? + .map(|d| self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)) + .transpose() }) } }