Properly implement SqliteStorage methods

This commit is contained in:
Sergey Levitin 2020-10-22 21:30:34 +03:00
parent fb996d943d
commit 16b0b47ecf

View file

@ -1,15 +1,14 @@
use super::{serializer::Serializer, Storage};
use futures::future::BoxFuture;
use serde::{de::DeserializeOwned, Serialize};
use sqlx::sqlite::{SqliteConnectOptions, SqliteConnection};
use sqlx::{ConnectOptions, Executor};
use sqlx::sqlite::SqlitePool;
use sqlx::Executor;
use std::{
convert::Infallible,
fmt::{Debug, Display},
sync::Arc,
};
use thiserror::Error;
use tokio::sync::Mutex;
// An error returned from [`SqliteStorage`].
#[derive(Debug, Error)]
@ -25,20 +24,23 @@ where
// TODO: make JSON serializer to be default
pub struct SqliteStorage<S> {
conn: Mutex<SqliteConnection>,
pool: SqlitePool,
serializer: S,
}
#[derive(sqlx::FromRow)]
struct DialogueDBRow {
dialogue: Vec<u8>,
}
impl<S> SqliteStorage<S> {
pub async fn open(
path: &str,
serializer: S,
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
let mut conn =
SqliteConnectOptions::new().filename(path).create_if_missing(true).connect().await?;
// TODO: think about a schema migration mechanism.
conn.execute(
let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?;
let mut conn = pool.acquire().await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS teloxide_dialogues (
chat_id BIGINT PRIMARY KEY,
@ -46,9 +48,26 @@ CREATE TABLE IF NOT EXISTS teloxide_dialogues (
);
"#,
)
.execute(&mut conn)
.await?;
Ok(Arc::new(Self { conn: Mutex::new(conn), serializer }))
Ok(Arc::new(Self { pool, serializer }))
}
}
async fn get_dialogue(
pool: &SqlitePool,
chat_id: i64,
) -> Result<Option<Box<Vec<u8>>>, sqlx::Error> {
match sqlx::query_as::<_, DialogueDBRow>(
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
)
.bind(chat_id)
.fetch_optional(pool)
.await?
{
Some(r) => Ok(Some(Box::new(r.dialogue))),
_ => Ok(None),
}
}
@ -65,14 +84,18 @@ where
chat_id: i64,
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move {
self.conn
.lock()
.await
.execute(
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?").bind(chat_id),
)
.await?;
Ok(None)
match get_dialogue(&self.pool, chat_id).await? {
None => Ok(None),
Some(d) => {
let prev_dialogue =
self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?;
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
.bind(chat_id)
.execute(&self.pool)
.await?;
Ok(Some(prev_dialogue))
}
}
})
}
@ -82,23 +105,31 @@ where
dialogue: D,
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move {
let dialogue =
let serialized_dialogue =
self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
self.conn
.lock()
.await
let prev_dialogue = get_dialogue(&self.pool, chat_id).await?;
self.pool
.acquire()
.await?
.execute(
sqlx::query(
r#"
INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ?
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
"#,
INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ?
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
"#,
)
.bind(chat_id)
.bind(dialogue),
.bind(serialized_dialogue),
)
.await?;
Ok(None)
Ok(match prev_dialogue {
None => None,
Some(d) => {
Some(self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?)
}
})
})
}
}