mirror of
https://github.com/teloxide/teloxide.git
synced 2024-10-24 01:47:08 +02:00
Properly implement SqliteStorage methods
This commit is contained in:
parent
fb996d943d
commit
16b0b47ecf
1 changed files with 58 additions and 27 deletions
|
@ -1,15 +1,14 @@
|
||||||
use super::{serializer::Serializer, Storage};
|
use super::{serializer::Serializer, Storage};
|
||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use sqlx::sqlite::{SqliteConnectOptions, SqliteConnection};
|
use sqlx::sqlite::SqlitePool;
|
||||||
use sqlx::{ConnectOptions, Executor};
|
use sqlx::Executor;
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
fmt::{Debug, Display},
|
fmt::{Debug, Display},
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::Mutex;
|
|
||||||
|
|
||||||
// An error returned from [`SqliteStorage`].
|
// An error returned from [`SqliteStorage`].
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
@ -25,20 +24,23 @@ where
|
||||||
|
|
||||||
// TODO: make JSON serializer to be default
|
// TODO: make JSON serializer to be default
|
||||||
pub struct SqliteStorage<S> {
|
pub struct SqliteStorage<S> {
|
||||||
conn: Mutex<SqliteConnection>,
|
pool: SqlitePool,
|
||||||
serializer: S,
|
serializer: S,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
struct DialogueDBRow {
|
||||||
|
dialogue: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<S> SqliteStorage<S> {
|
impl<S> SqliteStorage<S> {
|
||||||
pub async fn open(
|
pub async fn open(
|
||||||
path: &str,
|
path: &str,
|
||||||
serializer: S,
|
serializer: S,
|
||||||
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
|
) -> Result<Arc<Self>, SqliteStorageError<Infallible>> {
|
||||||
let mut conn =
|
let pool = SqlitePool::connect(format!("sqlite:{}?mode=rwc", path).as_str()).await?;
|
||||||
SqliteConnectOptions::new().filename(path).create_if_missing(true).connect().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
sqlx::query(
|
||||||
// TODO: think about a schema migration mechanism.
|
|
||||||
conn.execute(
|
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS teloxide_dialogues (
|
CREATE TABLE IF NOT EXISTS teloxide_dialogues (
|
||||||
chat_id BIGINT PRIMARY KEY,
|
chat_id BIGINT PRIMARY KEY,
|
||||||
|
@ -46,9 +48,26 @@ CREATE TABLE IF NOT EXISTS teloxide_dialogues (
|
||||||
);
|
);
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
|
.execute(&mut conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(Arc::new(Self { conn: Mutex::new(conn), serializer }))
|
Ok(Arc::new(Self { pool, serializer }))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_dialogue(
|
||||||
|
pool: &SqlitePool,
|
||||||
|
chat_id: i64,
|
||||||
|
) -> Result<Option<Box<Vec<u8>>>, sqlx::Error> {
|
||||||
|
match sqlx::query_as::<_, DialogueDBRow>(
|
||||||
|
"SELECT dialogue FROM teloxide_dialogues WHERE chat_id = ?",
|
||||||
|
)
|
||||||
|
.bind(chat_id)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
|
Some(r) => Ok(Some(Box::new(r.dialogue))),
|
||||||
|
_ => Ok(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,14 +84,18 @@ where
|
||||||
chat_id: i64,
|
chat_id: i64,
|
||||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
self.conn
|
match get_dialogue(&self.pool, chat_id).await? {
|
||||||
.lock()
|
None => Ok(None),
|
||||||
.await
|
Some(d) => {
|
||||||
.execute(
|
let prev_dialogue =
|
||||||
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?").bind(chat_id),
|
self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?;
|
||||||
)
|
sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?")
|
||||||
.await?;
|
.bind(chat_id)
|
||||||
Ok(None)
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
Ok(Some(prev_dialogue))
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,23 +105,31 @@ where
|
||||||
dialogue: D,
|
dialogue: D,
|
||||||
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let dialogue =
|
let serialized_dialogue =
|
||||||
self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
|
self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?;
|
||||||
self.conn
|
let prev_dialogue = get_dialogue(&self.pool, chat_id).await?;
|
||||||
.lock()
|
|
||||||
.await
|
self.pool
|
||||||
|
.acquire()
|
||||||
|
.await?
|
||||||
.execute(
|
.execute(
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ?
|
INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ?
|
||||||
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
|
ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue
|
||||||
"#,
|
"#,
|
||||||
)
|
)
|
||||||
.bind(chat_id)
|
.bind(chat_id)
|
||||||
.bind(dialogue),
|
.bind(serialized_dialogue),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(None)
|
|
||||||
|
Ok(match prev_dialogue {
|
||||||
|
None => None,
|
||||||
|
Some(d) => {
|
||||||
|
Some(self.serializer.deserialize(&d).map_err(SqliteStorageError::SerdeError)?)
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue