diff --git a/Cargo.toml b/Cargo.toml index d2dc00d2..d81f3cc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ serde_with_macros = "1.1.0" sqlx = { version = "0.4.0-beta.1", optional = true, default-features = false, features = [ "runtime-tokio", "macros", - "sqlite" + "sqlite", ] } redis = { version = "0.16.0", optional = true } serde_cbor = { version = "0.11.1", optional = true } diff --git a/src/dispatching/dialogue/mod.rs b/src/dispatching/dialogue/mod.rs index 45f67e06..32f50dd5 100644 --- a/src/dispatching/dialogue/mod.rs +++ b/src/dispatching/dialogue/mod.rs @@ -165,6 +165,6 @@ pub use teloxide_macros::Transition; pub use storage::{RedisStorage, RedisStorageError}; #[cfg(feature = "sqlite-storage")] -pub use storage::{SqliteStorage, SqliteStorageLocation, SqliteStorageError}; +pub use storage::{SqliteStorage, SqliteStorageError}; pub use storage::{serializer, InMemStorage, Serializer, Storage}; diff --git a/src/dispatching/dialogue/storage/mod.rs b/src/dispatching/dialogue/storage/mod.rs index 7fdf79d1..175ad183 100644 --- a/src/dispatching/dialogue/storage/mod.rs +++ b/src/dispatching/dialogue/storage/mod.rs @@ -17,7 +17,7 @@ pub use serializer::Serializer; use std::sync::Arc; #[cfg(feature = "sqlite-storage")] -pub use sqlite_storage::{SqliteStorage, SqliteStorageLocation, SqliteStorageError}; +pub use sqlite_storage::{SqliteStorage, SqliteStorageError}; /// A storage of dialogues. /// diff --git a/src/dispatching/dialogue/storage/sqlite_storage.rs b/src/dispatching/dialogue/storage/sqlite_storage.rs index 6356b4c7..32e7fd1f 100644 --- a/src/dispatching/dialogue/storage/sqlite_storage.rs +++ b/src/dispatching/dialogue/storage/sqlite_storage.rs @@ -1,18 +1,15 @@ -// use super::{serializer::Serializer, Storage}; -// use futures::future::BoxFuture; +use super::{serializer::Serializer, Storage}; +use futures::future::BoxFuture; +use serde::{de::DeserializeOwned, Serialize}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteConnection}; +use sqlx::{ConnectOptions, Executor}; use std::{ convert::Infallible, fmt::{Debug, Display}, + sync::Arc, }; -use sqlx::sqlite::SqlitePool; -// use serde::{de::DeserializeOwned, Serialize}; use thiserror::Error; -// use tokio::task::block_in_place; - -pub enum SqliteStorageLocation { - InMemory, - Path(String), -} +use tokio::sync::Mutex; // An error returned from [`SqliteStorage`]. #[derive(Debug, Error)] @@ -20,54 +17,88 @@ pub enum SqliteStorageError where SE: Debug + Display, { - #[error("parsing/serializing error: {0}")] + #[error("dialogue serialization error: {0}")] SerdeError(SE), - #[error("error from Sqlite: {0}")] - SqliteError(Box), + #[error("sqlite error: {0}")] + SqliteError(#[from] sqlx::Error), } +// TODO: make JSON serializer to be default pub struct SqliteStorage { - conn: SqlitePool, + conn: Mutex, serializer: S, } -impl SqliteStorage { +impl SqliteStorage { pub async fn open( - path: SqliteStorageLocation, + path: &str, serializer: S, - ) -> Result>{ - let url = match path { - SqliteStorageLocation::InMemory => String::from("sqlite::memory:"), - SqliteStorageLocation::Path(p) => p, - }; - Ok(Self { - conn: SqlitePool::connect(&url[..]).await - .expect("Impossible sqlite error"), - serializer, - }) + ) -> Result, SqliteStorageError> { + let mut conn = + SqliteConnectOptions::new().filename(path).create_if_missing(true).c§onnect().await?; + + // TODO: think about a schema migration mechanism. + conn.execute( + r#" +CREATE TABLE IF NOT EXISTS teloxide_dialogues ( + chat_id BIGINT PRIMARY KEY, + dialogue BLOB NOT NULL +); + "#, + ) + .await?; + + Ok(Arc::new(Self { conn: Mutex::new(conn), serializer })) } } -// impl Storage for SqliteStorage -// where -// S: Send + Sync + Serializer + 'static, -// D: Send + Serialize + DeserializeOwned + 'static, -// >::Error: Debug + Display, -// { -// type Error = SqliteStorageError<>::Error>; +impl Storage for SqliteStorage +where + S: Send + Sync + Serializer + 'static, + D: Send + Serialize + DeserializeOwned + 'static, + >::Error: Debug + Display, +{ + type Error = SqliteStorageError<>::Error>; -// fn remove_dialogue( -// self: Arc, -// chat_id: i64, -// ) -> BoxFuture<'static, Result, Self::Error>> { -// Box::pin(async move { -// todo!() -// }); -// } + fn remove_dialogue( + self: Arc, + chat_id: i64, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + self.conn + .lock() + .await + .execute( + sqlx::query("DELETE FROM teloxide_dialogues WHERE chat_id = ?").bind(chat_id), + ) + .await?; + Ok(None) + }) + } -// fn update_dialogue( -// self: Arc, -// chat_id: i64, -// dialogue: D -// ) { todo!() } -// } \ No newline at end of file + fn update_dialogue( + self: Arc, + chat_id: i64, + dialogue: D, + ) -> BoxFuture<'static, Result, Self::Error>> { + Box::pin(async move { + let dialogue = + self.serializer.serialize(&dialogue).map_err(SqliteStorageError::SerdeError)?; + self.conn + .lock() + .await + .execute( + sqlx::query( + r#" +INSERT INTO teloxide_dialogues VALUES (?, ?) WHERE chat_id = ? +ON CONFLICT(chat_id) DO UPDATE SET dialogue=excluded.dialogue + "#, + ) + .bind(chat_id) + .bind(dialogue), + ) + .await?; + Ok(None) + }) + } +} diff --git a/tests/sqlite.rs b/tests/sqlite.rs deleted file mode 100644 index aac879ff..00000000 --- a/tests/sqlite.rs +++ /dev/null @@ -1,19 +0,0 @@ -use teloxide::dispatching::dialogue::{ - serializer::{Bincode, CBOR, JSON}, - SqliteStorage, SqliteStorageLocation::InMemory -}; - -#[tokio::test] -async fn test_sqlite_json() { - let _storage = SqliteStorage::open(InMemory, JSON).await.unwrap(); -} - -#[tokio::test] -async fn test_sqlite_cbor() { - let _storage = SqliteStorage::open(InMemory, CBOR).await.unwrap(); -} - -#[tokio::test] -async fn test_sqlite_bincode() { - let _storage = SqliteStorage::open(InMemory, Bincode).await.unwrap(); -} \ No newline at end of file