From d703e6f97a0156177466b6741be0beac0c83d8c7 Mon Sep 17 00:00:00 2001
From: drbh <david.richard.holtz@gmail.com>
Date: Sun, 11 Feb 2024 06:27:12 -0500
Subject: [PATCH] feat: add simple tokio redis example (#2543)

---
 examples/tokio-redis/Cargo.toml  |  14 ++++
 examples/tokio-redis/src/main.rs | 106 +++++++++++++++++++++++++++++++
 2 files changed, 120 insertions(+)
 create mode 100644 examples/tokio-redis/Cargo.toml
 create mode 100644 examples/tokio-redis/src/main.rs

diff --git a/examples/tokio-redis/Cargo.toml b/examples/tokio-redis/Cargo.toml
new file mode 100644
index 00000000..fb276849
--- /dev/null
+++ b/examples/tokio-redis/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "example-tokio-redis"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[dependencies]
+axum = { path = "../../axum" }
+bb8 = "0.7.1"
+bb8-redis = "0.14.0"
+redis = "0.24.0"
+tokio = { version = "1.0", features = ["full"] }
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
diff --git a/examples/tokio-redis/src/main.rs b/examples/tokio-redis/src/main.rs
new file mode 100644
index 00000000..f0109f21
--- /dev/null
+++ b/examples/tokio-redis/src/main.rs
@@ -0,0 +1,106 @@
+//! Run with
+//!
+//! ```not_rust
+//! cargo run -p example-tokio-redis
+//! ```
+
+use axum::{
+    async_trait,
+    extract::{FromRef, FromRequestParts, State},
+    http::{request::Parts, StatusCode},
+    routing::get,
+    Router,
+};
+use bb8::{Pool, PooledConnection};
+use bb8_redis::RedisConnectionManager;
+use redis::AsyncCommands;
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+use bb8_redis::bb8;
+
+#[tokio::main]
+async fn main() {
+    tracing_subscriber::registry()
+        .with(
+            tracing_subscriber::EnvFilter::try_from_default_env()
+                .unwrap_or_else(|_| "example_tokio_redis=debug".into()),
+        )
+        .with(tracing_subscriber::fmt::layer())
+        .init();
+
+    tracing::debug!("connecting to redis");
+    let manager = RedisConnectionManager::new("redis://localhost").unwrap();
+    let pool = bb8::Pool::builder().build(manager).await.unwrap();
+
+    {
+        // ping the database before starting
+        let mut conn = pool.get().await.unwrap();
+        conn.set::<&str, &str, ()>("foo", "bar").await.unwrap();
+        let result: String = conn.get("foo").await.unwrap();
+        assert_eq!(result, "bar");
+    }
+    tracing::debug!("successfully connected to redis and pinged it");
+
+    // build our application with some routes
+    let app = Router::new()
+        .route(
+            "/",
+            get(using_connection_pool_extractor).post(using_connection_extractor),
+        )
+        .with_state(pool);
+
+    // run it
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
+        .await
+        .unwrap();
+    tracing::debug!("listening on {}", listener.local_addr().unwrap());
+    axum::serve(listener, app).await.unwrap();
+}
+
+type ConnectionPool = Pool<RedisConnectionManager>;
+
+async fn using_connection_pool_extractor(
+    State(pool): State<ConnectionPool>,
+) -> Result<String, (StatusCode, String)> {
+    let mut conn = pool.get().await.map_err(internal_error)?;
+    let result: String = conn.get("foo").await.map_err(internal_error)?;
+    Ok(result)
+}
+
+// we can also write a custom extractor that grabs a connection from the pool
+// which setup is appropriate depends on your application
+struct DatabaseConnection(PooledConnection<'static, RedisConnectionManager>);
+
+#[async_trait]
+impl<S> FromRequestParts<S> for DatabaseConnection
+where
+    ConnectionPool: FromRef<S>,
+    S: Send + Sync,
+{
+    type Rejection = (StatusCode, String);
+
+    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
+        let pool = ConnectionPool::from_ref(state);
+
+        let conn = pool.get_owned().await.map_err(internal_error)?;
+
+        Ok(Self(conn))
+    }
+}
+
+async fn using_connection_extractor(
+    DatabaseConnection(mut conn): DatabaseConnection,
+) -> Result<String, (StatusCode, String)> {
+    let result: String = conn.get("foo").await.map_err(internal_error)?;
+
+    Ok(result)
+}
+
+/// Utility function for mapping any error into a `500 Internal Server Error`
+/// response.
+fn internal_error<E>(err: E) -> (StatusCode, String)
+where
+    E: std::error::Error,
+{
+    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
+}