From 16b0b47ecfec09169d05ea36c2b5dd88caacf702 Mon Sep 17 00:00:00 2001 From: Sergey Levitin Date: Thu, 22 Oct 2020 21:30:34 +0300 Subject: [PATCH] Properly implement SqliteStorage methods --- .../dialogue/storage/sqlite_storage.rs | 85 +++++++++++++------ 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/src/dispatching/dialogue/storage/sqlite_storage.rs b/src/dispatching/dialogue/storage/sqlite_storage.rs index 123e8bbc..197a6d10 100644 --- a/src/dispatching/dialogue/storage/sqlite_storage.rs +++ b/src/dispatching/dialogue/storage/sqlite_storage.rs @@ -1,15 +1,14 @@ use super::{serializer::Serializer, Storage}; use futures::future::BoxFuture; use serde::{de::DeserializeOwned, Serialize}; -use sqlx::sqlite::{SqliteConnectOptions, SqliteConnection}; -use sqlx::{ConnectOptions, Executor}; +use sqlx::sqlite::SqlitePool; +use sqlx::Executor; use std::{ convert::Infallible, fmt::{Debug, Display}, sync::Arc, }; use thiserror::Error; -use tokio::sync::Mutex; // An error returned from [`SqliteStorage`]. #[derive(Debug, Error)] @@ -25,20 +24,23 @@ where // TODO: make JSON serializer to be default pub struct SqliteStorage { - conn: Mutex, + pool: SqlitePool, serializer: S, } +#[derive(sqlx::FromRow)] +struct DialogueDBRow { + dialogue: Vec, +} + impl SqliteStorage { pub async fn open( path: &str, serializer: S, ) -> Result, SqliteStorageError> { - let mut conn = - SqliteConnectOptions::new().filename(path).create_if_missing(true).connect().await?; - - // TODO: think about a schema migration mechanism. - conn.execute( + let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?; + let mut conn = pool.acquire().await?; + sqlx::query( r#" CREATE TABLE IF NOT EXISTS teloxide_dialogues ( chat_id BIGINT PRIMARY KEY, @@ -46,9 +48,26 @@ CREATE TABLE IF NOT EXISTS teloxide_dialogues ( ); "#, ) + .execute(&mut conn) .await?; - Ok(Arc::new(Self { conn: Mutex::new(conn), serializer })) + Ok(Arc::new(Self { pool, serializer })) + } +} + +async fn get_dialogue( + pool: &SqlitePool, + chat_id: i64, +) -> Result>>, sqlx::Error> { + match sqlx::query_as::<_, DialogueDBRow>( + "SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?", + ) + .bind(chat_id) + .fetch_optional(pool) + .await? + { + Some(r) => Ok(Some(Box::new(r.dialogue))), + _ => Ok(None), } } @@ -65,14 +84,18 @@ where chat_id: i64, ) -> BoxFuture<'static, Result, Self::Error>> { Box::pin(async move { - self.conn - .lock() - .await - .execute( - sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?").bind(chat_id), - ) - .await?; - Ok(None) + match get_dialogue(&self.pool, chat_id).await? { + None => Ok(None), + Some(d) => { + let prev_dialogue = + self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?; + sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?") + .bind(chat_id) + .execute(&self.pool) + .await?; + Ok(Some(prev_dialogue)) + } + } }) } @@ -82,23 +105,31 @@ where dialogue: D, ) -> BoxFuture<'static, Result, Self::Error>> { Box::pin(async move { - let dialogue = + let serialized_dialogue = self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; - self.conn - .lock() - .await + let prev_dialogue = get_dialogue(&self.pool, chat_id).await?; + + self.pool + .acquire() + .await? .execute( sqlx::query( r#" -INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ? -ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue - "#, + INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ? + ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue + "#, ) .bind(chat_id) - .bind(dialogue), + .bind(serialized_dialogue), ) .await?; - Ok(None) + + Ok(match prev_dialogue { + None => None, + Some(d) => { + Some(self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?) + } + }) }) } }