From 01a4c88d7a42a6ce410833e23a1a27d339a6d316 Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Fri, 1 Apr 2022 16:37:55 +0200
Subject: [PATCH] Add `PrivateCookieJar` (#900)

---
 axum-extra/CHANGELOG.md                  |   1 +
 axum-extra/Cargo.toml                    |   2 +-
 axum-extra/src/extract/cookie.rs         |   7 +-
 axum-extra/src/extract/cookie/private.rs | 220 +++++++++++++++++++++++
 axum-extra/src/extract/mod.rs            |   2 +-
 5 files changed, 229 insertions(+), 3 deletions(-)
 create mode 100644 axum-extra/src/extract/cookie/private.rs

diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md
index b9dba964..3eac0ab9 100644
--- a/axum-extra/CHANGELOG.md
+++ b/axum-extra/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning].
 
 - **added:** Re-export `SameSite` and `Expiration` from the `cookie` crate.
 - **fixed:** Fix `SignedCookieJar` when using custom key types ([#899])
+- **added:** `PrivateCookieJar` for managing private cookies
 
 [#899]: https://github.com/tokio-rs/axum/pull/899
 
diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml
index 42411895..63b2f2ec 100644
--- a/axum-extra/Cargo.toml
+++ b/axum-extra/Cargo.toml
@@ -32,7 +32,7 @@ axum-macros = { path = "../axum-macros", version = "0.2", optional = true }
 serde = { version = "1.0", optional = true }
 serde_json = { version = "1.0.71", optional = true }
 percent-encoding = { version = "2.1", optional = true }
-cookie-lib = { package = "cookie", version = "0.16", features = ["percent-encode", "signed"], optional = true }
+cookie-lib = { package = "cookie", version = "0.16", features = ["percent-encode", "signed", "private"], optional = true }
 
 [dev-dependencies]
 axum = { path = "../axum", version = "0.5", features = ["headers"] }
diff --git a/axum-extra/src/extract/cookie.rs b/axum-extra/src/extract/cookie.rs
index 56cfd399..096c6716 100644
--- a/axum-extra/src/extract/cookie.rs
+++ b/axum-extra/src/extract/cookie.rs
@@ -1,6 +1,6 @@
 //! Cookie parsing and cookie jar management.
 //!
-//! See [`CookieJar`] and [`SignedCookieJar`] for more details.
+//! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details.
 
 use axum::{
     async_trait,
@@ -15,6 +15,9 @@ use http::{
 };
 use std::{convert::Infallible, fmt, marker::PhantomData};
 
+mod private;
+
+pub use self::private::PrivateCookieJar;
 pub use cookie_lib::{Cookie, Expiration, Key, SameSite};
 
 /// Extractor that grabs cookies from the request and manages the jar.
@@ -494,6 +497,8 @@ mod tests {
     cookie_test!(plaintext_cookies, CookieJar);
     cookie_test!(signed_cookies, SignedCookieJar);
     cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);
+    cookie_test!(private_cookies, PrivateCookieJar);
+    cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);
 
     #[derive(Clone)]
     struct CustomKey(Key);
diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs
new file mode 100644
index 00000000..a896d110
--- /dev/null
+++ b/axum-extra/src/extract/cookie/private.rs
@@ -0,0 +1,220 @@
+use super::{cookies_from_request, set_cookies, Cookie, Key};
+use axum::{
+    async_trait,
+    extract::{FromRequest, RequestParts},
+    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
+    Extension,
+};
+use cookie_lib::PrivateJar;
+use std::{convert::Infallible, fmt, marker::PhantomData};
+
+/// Extractor that grabs private cookies from the request and manages the jar.
+///
+/// All cookies will be private and encrypted with a [`Key`]. This makes it suitable for storing
+/// private data.
+///
+/// Note that methods like [`PrivateCookieJar::add`], [`PrivateCookieJar::remove`], etc updates the
+/// [`PrivateCookieJar`] and returns it. This value _must_ be returned from the handler as part of
+/// the response for the changes to be propagated.
+///
+/// # Example
+///
+/// ```rust
+/// use axum::{
+///     Router,
+///     Extension,
+///     routing::{post, get},
+///     extract::TypedHeader,
+///     response::{IntoResponse, Redirect},
+///     headers::authorization::{Authorization, Bearer},
+///     http::StatusCode,
+/// };
+/// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie, Key};
+///
+/// async fn set_secret(
+///     jar: PrivateCookieJar,
+/// ) -> impl IntoResponse {
+///     let updated_jar = jar.add(Cookie::new("secret", "secret-data"));
+///     (updated_jar, Redirect::to("/get"))
+/// }
+///
+/// async fn get_secret(jar: PrivateCookieJar) -> impl IntoResponse {
+///     if let Some(data) = jar.get("secret") {
+///         // ...
+///     }
+/// }
+///
+/// // Generate a secure key
+/// //
+/// // You probably don't wanna generate a new one each time the app starts though
+/// let key = Key::generate();
+///
+/// let app = Router::new()
+///     .route("/set", post(set_secret))
+///     .route("/get", get(get_secret))
+///     // add extension with the key so `PrivateCookieJar` can access it
+///     .layer(Extension(key));
+/// # let app: Router<axum::body::Body> = app;
+/// ```
+pub struct PrivateCookieJar<K = Key> {
+    jar: cookie_lib::CookieJar,
+    key: Key,
+    // The key used to extract the key extension. Allows users to use multiple keys for different
+    // jars. Maybe a library wants its own key.
+    _marker: PhantomData<K>,
+}
+
+impl<K> fmt::Debug for PrivateCookieJar<K> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("PrivateCookieJar")
+            .field("jar", &self.jar)
+            .field("key", &"REDACTED")
+            .finish()
+    }
+}
+
+#[async_trait]
+impl<B, K> FromRequest<B> for PrivateCookieJar<K>
+where
+    B: Send,
+    K: Into<Key> + Clone + Send + Sync + 'static,
+{
+    type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection;
+
+    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+        let key = Extension::<K>::from_request(req).await?.0.into();
+
+        let mut jar = cookie_lib::CookieJar::new();
+        let mut private_jar = jar.private_mut(&key);
+        for cookie in cookies_from_request(req) {
+            if let Some(cookie) = private_jar.decrypt(cookie) {
+                private_jar.add_original(cookie);
+            }
+        }
+
+        Ok(Self {
+            jar,
+            key,
+            _marker: PhantomData,
+        })
+    }
+}
+
+impl<K> PrivateCookieJar<K> {
+    /// Get a cookie from the jar.
+    ///
+    /// If the cookie exists and can be decrypted then it is returned in plaintext.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use axum_extra::extract::cookie::PrivateCookieJar;
+    /// use axum::response::IntoResponse;
+    ///
+    /// async fn handle(jar: PrivateCookieJar) {
+    ///     let value: Option<String> = jar
+    ///         .get("foo")
+    ///         .map(|cookie| cookie.value().to_owned());
+    /// }
+    /// ```
+    pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
+        self.private_jar().get(name)
+    }
+
+    /// Remove a cookie from the jar.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie};
+    /// use axum::response::IntoResponse;
+    ///
+    /// async fn handle(jar: PrivateCookieJar) -> impl IntoResponse {
+    ///     jar.remove(Cookie::named("foo"))
+    /// }
+    /// ```
+    #[must_use]
+    pub fn remove(mut self, cookie: Cookie<'static>) -> Self {
+        self.private_jar_mut().remove(cookie);
+        self
+    }
+
+    /// Add a cookie to the jar.
+    ///
+    /// The value will automatically be percent-encoded.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use axum_extra::extract::cookie::{PrivateCookieJar, Cookie};
+    /// use axum::response::IntoResponse;
+    ///
+    /// async fn handle(jar: PrivateCookieJar) -> impl IntoResponse {
+    ///     jar.add(Cookie::new("foo", "bar"))
+    /// }
+    /// ```
+    #[must_use]
+    #[allow(clippy::should_implement_trait)]
+    pub fn add(mut self, cookie: Cookie<'static>) -> Self {
+        self.private_jar_mut().add(cookie);
+        self
+    }
+
+    /// Authenticates and decrypts `cookie`, returning the plaintext version if decryption succeeds
+    /// or `None` otherwise.
+    pub fn decrypt(&self, cookie: Cookie<'static>) -> Option<Cookie<'static>> {
+        self.private_jar().decrypt(cookie)
+    }
+
+    /// Get an iterator over all cookies in the jar.
+    ///
+    /// Only cookies with valid authenticity and integrity are yielded by the iterator.
+    pub fn iter(&self) -> impl Iterator<Item = Cookie<'static>> + '_ {
+        PrivateCookieJarIter {
+            jar: self,
+            iter: self.jar.iter(),
+        }
+    }
+
+    fn private_jar(&self) -> PrivateJar<&'_ cookie_lib::CookieJar> {
+        self.jar.private(&self.key)
+    }
+
+    fn private_jar_mut(&mut self) -> PrivateJar<&'_ mut cookie_lib::CookieJar> {
+        self.jar.private_mut(&self.key)
+    }
+}
+
+impl<K> IntoResponseParts for PrivateCookieJar<K> {
+    type Error = Infallible;
+
+    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
+        set_cookies(self.jar, res.headers_mut());
+        Ok(res)
+    }
+}
+
+impl<K> IntoResponse for PrivateCookieJar<K> {
+    fn into_response(self) -> Response {
+        (self, ()).into_response()
+    }
+}
+
+struct PrivateCookieJarIter<'a, K> {
+    jar: &'a PrivateCookieJar<K>,
+    iter: cookie_lib::Iter<'a>,
+}
+
+impl<'a, K> Iterator for PrivateCookieJarIter<'a, K> {
+    type Item = Cookie<'static>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            let cookie = self.iter.next()?;
+
+            if let Some(cookie) = self.jar.get(cookie.name()) {
+                return Some(cookie);
+            }
+        }
+    }
+}
diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs
index 77f02eda..3db1248c 100644
--- a/axum-extra/src/extract/mod.rs
+++ b/axum-extra/src/extract/mod.rs
@@ -7,4 +7,4 @@ pub mod cookie;
 pub use self::cached::Cached;
 
 #[cfg(feature = "cookie")]
-pub use self::cookie::{CookieJar, SignedCookieJar};
+pub use self::cookie::{CookieJar, PrivateCookieJar, SignedCookieJar};