Make Serializer a trait, so anyone can implement it

This commit is contained in:
Maximilian Siling 2020-04-19 20:06:49 +03:00
parent 478e7038a6
commit 82d0958c91
4 changed files with 102 additions and 78 deletions

View file

@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use teloxide::{ use teloxide::{
dispatching::dialogue::{RedisStorage, Serializer, Storage}, dispatching::dialogue::{serializer::Bincode, RedisStorage, Storage},
prelude::*, prelude::*,
types::{KeyboardButton, ReplyKeyboardMarkup}, types::{KeyboardButton, ReplyKeyboardMarkup},
}; };
@ -94,7 +94,7 @@ enum Dialogue {
type Cx<State> = DialogueDispatcherHandlerCx< type Cx<State> = DialogueDispatcherHandlerCx<
Message, Message,
State, State,
<RedisStorage as Storage<Dialogue>>::Error, <RedisStorage<Bincode> as Storage<Dialogue>>::Error,
>; >;
type Res = ResponseResult<DialogueStage<Dialogue>>; type Res = ResponseResult<DialogueStage<Dialogue>>;
@ -202,7 +202,7 @@ async fn run() {
// All serializer but JSON require enabling feature // All serializer but JSON require enabling feature
// "serializer-<name>", e. g. "serializer-cbor" // "serializer-<name>", e. g. "serializer-cbor"
// or "serializer-bincode" // or "serializer-bincode"
RedisStorage::open("redis://127.0.0.1:6379", Serializer::Bincode) RedisStorage::open("redis://127.0.0.1:6379", Bincode)
.await .await
.unwrap(), .unwrap(),
), ),

View file

@ -56,4 +56,4 @@ pub use dialogue_stage::{exit, next, DialogueStage};
pub use get_chat_id::GetChatId; pub use get_chat_id::GetChatId;
#[cfg(feature = "redis-storage")] #[cfg(feature = "redis-storage")]
pub use storage::RedisStorage; pub use storage::RedisStorage;
pub use storage::{InMemStorage, Serializer, Storage}; pub use storage::{serializer, InMemStorage, Serializer, Storage};

View file

@ -1,34 +1,37 @@
use super::{ use super::{serializer::Serializer, Storage};
serializer::{self, Serializer},
Storage,
};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo}; use redis::{AsyncCommands, FromRedisValue, IntoConnectionInfo};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::{ops::DerefMut, sync::Arc}; use std::{
convert::Infallible,
fmt::{Debug, Display},
ops::DerefMut,
sync::Arc,
};
use thiserror::Error; use thiserror::Error;
use tokio::sync::Mutex; use tokio::sync::Mutex;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error<SE>
#[error("{0}")] where
SerdeError(#[from] serializer::Error), SE: Debug + Display,
{
#[error("parsing/serializing error: {0}")]
SerdeError(SE),
#[error("error from Redis: {0}")] #[error("error from Redis: {0}")]
RedisError(#[from] redis::RedisError), RedisError(#[from] redis::RedisError),
} }
type Result<T, E = Error> = std::result::Result<T, E>; pub struct RedisStorage<S> {
pub struct RedisStorage {
conn: Mutex<redis::aio::Connection>, conn: Mutex<redis::aio::Connection>,
serializer: Serializer, serializer: S,
} }
impl RedisStorage { impl<S> RedisStorage<S> {
pub async fn open( pub async fn open(
url: impl IntoConnectionInfo, url: impl IntoConnectionInfo,
serializer: Serializer, serializer: S,
) -> Result<Self> { ) -> Result<Self, Error<Infallible>> {
Ok(Self { Ok(Self {
conn: Mutex::new( conn: Mutex::new(
redis::Client::open(url)?.get_async_connection().await?, redis::Client::open(url)?.get_async_connection().await?,
@ -38,36 +41,42 @@ impl RedisStorage {
} }
} }
impl<D> Storage<D> for RedisStorage impl<S, D> Storage<D> for RedisStorage<S>
where where
S: Send + Sync + Serializer<D> + 'static,
D: Send + Serialize + DeserializeOwned + 'static, D: Send + Serialize + DeserializeOwned + 'static,
<S as Serializer<D>>::Error: Debug + Display,
{ {
type Error = Error; type Error = Error<<S as Serializer<D>>::Error>;
// `.del().ignore()` is much more readable than `.del()\n.ignore()` // `.del().ignore()` is much more readable than `.del()\n.ignore()`
#[rustfmt::skip] #[rustfmt::skip]
fn remove_dialogue( fn remove_dialogue(
self: Arc<Self>, self: Arc<Self>,
chat_id: i64, chat_id: i64,
) -> BoxFuture<'static, Result<Option<D>>> { ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move { Box::pin(async move {
let res = redis::pipe() let res = redis::pipe()
.atomic() .atomic()
.get(chat_id) .get(chat_id)
.del(chat_id).ignore() .del(chat_id).ignore()
.query_async::<_, redis::Value>(self.conn.lock().await.deref_mut()) .query_async::<_, redis::Value>(
self.conn.lock().await.deref_mut(),
)
.await?; .await?;
// We're expecting `.pipe()` to return us an exactly one result in bulk, // We're expecting `.pipe()` to return us an exactly one result in
// so all other branches should be unreachable // bulk, so all other branches should be unreachable
match res { match res {
redis::Value::Bulk(bulk) if bulk.len() == 1 => { redis::Value::Bulk(bulk) if bulk.len() == 1 => {
Ok( Ok(Option::<Vec<u8>>::from_redis_value(&bulk[0])?
Option::<Vec<u8>>::from_redis_value(&bulk[0])? .map(|v| {
.map(|v| self.serializer.deserialize(&v)) self.serializer
.transpose()? .deserialize(&v)
) .map_err(Error::SerdeError)
}, })
_ => unreachable!() .transpose()?)
}
_ => unreachable!(),
} }
}) })
} }
@ -76,16 +85,21 @@ where
self: Arc<Self>, self: Arc<Self>,
chat_id: i64, chat_id: i64,
dialogue: D, dialogue: D,
) -> BoxFuture<'static, Result<Option<D>>> { ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
Box::pin(async move { Box::pin(async move {
let dialogue = self.serializer.serialize(&dialogue)?; let dialogue = self
.serializer
.serialize(&dialogue)
.map_err(Error::SerdeError)?;
Ok(self Ok(self
.conn .conn
.lock() .lock()
.await .await
.getset::<_, Vec<u8>, Option<Vec<u8>>>(chat_id, dialogue) .getset::<_, Vec<u8>, Option<Vec<u8>>>(chat_id, dialogue)
.await? .await?
.map(|d| self.serializer.deserialize(&d)) .map(|d| {
self.serializer.deserialize(&d).map_err(Error::SerdeError)
})
.transpose()?) .transpose()?)
}) })
} }

View file

@ -1,53 +1,63 @@
use serde::{de::DeserializeOwned, ser::Serialize}; use serde::{de::DeserializeOwned, ser::Serialize};
use thiserror::Error;
use Serializer::*;
#[derive(Debug, Error)] pub trait Serializer<D> {
pub enum Error { type Error;
#[error("failed parsing/serializing JSON: {0}")]
JSONError(#[from] serde_json::Error), fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error>;
#[cfg(feature = "cbor-serializer")] fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error>;
#[error("failed parsing/serializing CBOR: {0}")]
CBORError(#[from] serde_cbor::Error),
#[cfg(feature = "bincode-serializer")]
#[error("failed parsing/serializing Bincode: {0}")]
BincodeError(#[from] bincode::Error),
} }
type Result<T, E = Error> = std::result::Result<T, E>; pub struct JSON;
pub enum Serializer { impl<D> Serializer<D> for JSON
JSON, where
#[cfg(feature = "cbor-serializer")] D: Serialize + DeserializeOwned,
CBOR, {
#[cfg(feature = "bincode-serializer")] type Error = serde_json::Error;
Bincode,
}
impl Serializer { fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
pub fn serialize<D>(&self, val: &D) -> Result<Vec<u8>> serde_json::to_vec(val)
where
D: Serialize,
{
Ok(match self {
JSON => serde_json::to_vec(val)?,
#[cfg(feature = "cbor-serializer")]
CBOR => serde_cbor::to_vec(val)?,
#[cfg(feature = "bincode-serializer")]
Bincode => bincode::serialize(val)?,
})
} }
pub fn deserialize<'de, D>(&self, data: &'de [u8]) -> Result<D> fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
where serde_json::from_slice(data)
D: DeserializeOwned, }
{ }
Ok(match self {
JSON => serde_json::from_slice(data)?, #[cfg(feature = "cbor-serializer")]
#[cfg(feature = "cbor-serializer")] pub struct CBOR;
CBOR => serde_cbor::from_slice(data)?,
#[cfg(feature = "bincode-serializer")] #[cfg(feature = "cbor-serializer")]
Bincode => bincode::deserialize(data)?, impl<D> Serializer<D> for CBOR
}) where
D: Serialize + DeserializeOwned,
{
type Error = serde_cbor::Error;
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
serde_cbor::to_vec(val)
}
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
serde_cbor::from_slice(data)
}
}
#[cfg(feature = "bincode-serializer")]
pub struct Bincode;
#[cfg(feature = "bincode-serializer")]
impl<D> Serializer<D> for Bincode
where
D: Serialize + DeserializeOwned,
{
type Error = bincode::Error;
fn serialize(&self, val: &D) -> Result<Vec<u8>, Self::Error> {
bincode::serialize(val)
}
fn deserialize(&self, data: &[u8]) -> Result<D, Self::Error> {
bincode::deserialize(data)
} }
} }