diff --git a/Cargo.toml b/Cargo.toml index fa0d4a70..5ff8f99c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ tokio-postgres = "0.7.2" tracing = "0.1" tracing-subscriber = "0.2" uuid = { version = "0.8", features = ["serde", "v4"] } +async-session = "3.0.0" [dev-dependencies.tower] version = "0.4" diff --git a/deny.toml b/deny.toml index 966efc09..adc24c1a 100644 --- a/deny.toml +++ b/deny.toml @@ -37,12 +37,13 @@ multiple-versions = "deny" highlight = "all" skip-tree = [] skip = [ - # iri-string uses old version - # iri-string pulled in by tower-http - # PR to update tower-http is https://github.com/tower-rs/tower-http/pull/110 - { name = "nom", version = "=5.1.2" }, - # rustls uses old version + # rustls uses old version (dev dep) { name = "spin", version = "=0.5.2" }, + # tokio-postgres uses old version (dev dep) + { name = "hmac", version = "=0.10.1" }, + { name = "crypto-mac" }, + # async-session uses old version (dev dep) + { name = "cfg-if", version = "=0.1.10" }, ] [sources] diff --git a/examples/README.md b/examples/README.md index d1e3dc8d..8bf8ecde 100644 --- a/examples/README.md +++ b/examples/README.md @@ -13,3 +13,5 @@ - [`error_handling_and_dependency_injection`](../examples/error_handling_and_dependency_injection.rs) - How to handle errors and dependency injection using trait objects. - [`tokio_postgres`](../examples/tokio_postgres.rs) - How to use a tokio-postgres and bb8 to query a database. - [`unix_domain_socket`](../examples/unix_domain_socket.rs) - How to run an Axum server over unix domain sockets. +- [`sessions`](../examples/sessions.rs) - Sessions and cookies using [`async-session`](https://crates.io/crates/async-session). +- [`tls_rustls`](../examples/tls_rustls.rs) - TLS with [`tokio-rustls`](https://crates.io/crates/tokio-rustls). diff --git a/examples/sessions.rs b/examples/sessions.rs new file mode 100644 index 00000000..2843cbe2 --- /dev/null +++ b/examples/sessions.rs @@ -0,0 +1,109 @@ +use async_session::{MemoryStore, Session, SessionStore as _}; +use axum::{ + async_trait, + extract::{FromRequest, RequestParts}, + prelude::*, + response::IntoResponse, + AddExtensionLayer, +}; +use headers::{HeaderMap, HeaderValue}; +use http::StatusCode; +use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; +use uuid::Uuid; + +#[tokio::main] +async fn main() { + // `MemoryStore` just used as an example. Don't use this in production. + let store = MemoryStore::new(); + + let app = route("/", get(handler)).layer(AddExtensionLayer::new(store)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn handler(user_id: UserIdFromSession) -> impl IntoResponse { + let (headers, user_id) = match user_id { + UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id), + UserIdFromSession::CreatedFreshUserId { user_id, cookie } => { + let mut headers = HeaderMap::new(); + headers.insert(http::header::SET_COOKIE, cookie); + (headers, user_id) + } + }; + + dbg!(user_id); + + headers +} + +enum UserIdFromSession { + FoundUserId(UserId), + CreatedFreshUserId { + user_id: UserId, + cookie: HeaderValue, + }, +} + +#[async_trait] +impl FromRequest for UserIdFromSession +where + B: Send, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request(req: &mut RequestParts) -> Result { + let extract::Extension(store) = extract::Extension::::from_request(req) + .await + .expect("`MemoryStore` extension missing"); + + let headers = req.headers().expect("other extractor taken headers"); + + let cookie = if let Some(cookie) = headers + .get(http::header::COOKIE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()) + { + cookie + } else { + let user_id = UserId::new(); + let mut session = Session::new(); + session.insert("user_id", user_id).unwrap(); + let cookie = store.store_session(session).await.unwrap().unwrap(); + + return Ok(Self::CreatedFreshUserId { + user_id, + cookie: cookie.parse().unwrap(), + }); + }; + + let user_id = if let Some(session) = store.load_session(cookie).await.unwrap() { + if let Some(user_id) = session.get::("user_id") { + user_id + } else { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "No `user_id` found in session", + )); + } + } else { + return Err((StatusCode::BAD_REQUEST, "No session found for cookie")); + }; + + Ok(Self::FoundUserId(user_id)) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, Copy)] +struct UserId(Uuid); + +impl UserId { + fn new() -> Self { + Self(Uuid::new_v4()) + } +}