Properly implement SqliteStorage methods

This commit is contained in:
Sergey Levitin 2020-10-22 21:30:34 +03:00
parent fb996d943d
commit 16b0b47ecf

View file

@ -1,15 +1,14 @@
use super::{serializer::Serializer, Storage}; use super::{serializer::Serializer, Storage};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use sqlx::sqlite::{SqliteConnectOptions, SqliteConnection}; use sqlx::sqlite::SqlitePool;
use sqlx::{ConnectOptions, Executor}; use sqlx::Executor;
use std::{ use std::{
convert::Infallible, convert::Infallible,
fmt::{Debug, Display}, fmt::{Debug, Display},
sync::Arc, sync::Arc,
}; };
use thiserror::Error; use thiserror::Error;
use tokio::sync::Mutex;
// An error returned from [`SqliteStorage`]. // An error returned from [`SqliteStorage`].
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -25,20 +24,23 @@ where
// TODO: make JSON serializer to be default // TODO: make JSON serializer to be default
pub struct SqliteStorage<S> { pub struct SqliteStorage<S> {
conn: Mutex<SqliteConnection>, pool: SqlitePool,
serializer: S, serializer: S,
} }
#[derive(sqlx::FromRow)]
struct DialogueDBRow {
dialogue: Vec<u8>,
}
impl<S> SqliteStorage<S> { impl<S> SqliteStorage<S> {
pub async fn open( pub async fn open(
path: &str, path: &str,
serializer: S, serializer: S,
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> { ) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
let mut conn = let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?;
SqliteConnectOptions::new().filename(path).create_if_missing(true).connect().await?; let mut conn = pool.acquire().await?;
sqlx::query(
// TODO: think about a schema migration mechanism.
conn.execute(
r#" r#"
CREATE TABLE IF NOT EXISTS teloxide_dialogues ( CREATE TABLE IF NOT EXISTS teloxide_dialogues (
chat_id BIGINT PRIMARY KEY, chat_id BIGINT PRIMARY KEY,
@ -46,9 +48,26 @@ CREATE TABLE IF NOT EXISTS teloxide_dialogues (
); );
"#, "#,
) )
.execute(&mut conn)
.await?; .await?;
Ok(Arc::new(Self { conn: Mutex::new(conn), serializer })) Ok(Arc::new(Self { pool, serializer }))
}
}
async fn get_dialogue(
pool: &SqlitePool,
chat_id: i64,
) -> Result<Option<Box<Vec<u8>>>, sqlx::Error> {
match sqlx::query_as::<_, DialogueDBRow>(
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
)
.bind(chat_id)
.fetch_optional(pool)
.await?
{
Some(r) => Ok(Some(Box::new(r.dialogue))),
_ => Ok(None),
} }
} }
@ -65,14 +84,18 @@ where
chat_id: i64, chat_id: i64,
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move { Box::pin(async move {
self.conn match get_dialogue(&self.pool, chat_id).await? {
.lock() None => Ok(None),
.await Some(d) => {
.execute( let prev_dialogue =
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?").bind(chat_id), self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?;
) sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
.await?; .bind(chat_id)
Ok(None) .execute(&self.pool)
.await?;
Ok(Some(prev_dialogue))
}
}
}) })
} }
@ -82,23 +105,31 @@ where
dialogue: D, dialogue: D,
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> { ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move { Box::pin(async move {
let dialogue = let serialized_dialogue =
self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
self.conn let prev_dialogue = get_dialogue(&self.pool, chat_id).await?;
.lock()
.await self.pool
.acquire()
.await?
.execute( .execute(
sqlx::query( sqlx::query(
r#" r#"
INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ? INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ?
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
"#, "#,
) )
.bind(chat_id) .bind(chat_id)
.bind(dialogue), .bind(serialized_dialogue),
) )
.await?; .await?;
Ok(None)
Ok(match prev_dialogue {
None => None,
Some(d) => {
Some(self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?)
}
})
}) })
} }
} }