From 3841ef44d563021a2bf335173ff72c574914d358 Mon Sep 17 00:00:00 2001 From: ttys3 <41882455+ttys3@users.noreply.github.com> Date: Wed, 22 Dec 2021 22:27:13 +0800 Subject: [PATCH] Fix session cookie example (#638) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: refine session cookie example * refactor: refine session_cookie extraction * refactor: avoid to_owned() * chore: refine debug log Co-authored-by: 荒野無燈 --- examples/sessions/Cargo.toml | 2 +- examples/sessions/src/main.rs | 86 ++++++++++++++++++++++++----------- 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/examples/sessions/Cargo.toml b/examples/sessions/Cargo.toml index c46f20b4..fa65fcdb 100644 --- a/examples/sessions/Cargo.toml +++ b/examples/sessions/Cargo.toml @@ -5,7 +5,7 @@ edition = "2018" publish = false [dependencies] -axum = { path = "../../axum" } +axum = { path = "../../axum", features = ["headers"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version="0.3", features = ["env-filter"] } diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index f16e800d..7ac6a8e4 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,7 +7,8 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{Extension, FromRequest, RequestParts, TypedHeader}, + headers::Cookie, http::{ self, header::{HeaderMap, HeaderValue}, @@ -18,9 +19,12 @@ use axum::{ AddExtensionLayer, Router, }; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::net::SocketAddr; use uuid::Uuid; +const AXUM_SESSION_COOKIE_NAME: &str = "axum_session"; + #[tokio::main] async fn main() { // Set the RUST_LOG, if it hasn't been explicitly defined @@ -45,26 +49,34 @@ async fn main() { } async fn handler(user_id: UserIdFromSession) -> impl IntoResponse { - let (headers, user_id) = match user_id { - UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id), - UserIdFromSession::CreatedFreshUserId { user_id, cookie } => { + let (headers, user_id, create_cookie) = match user_id { + UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id, false), + UserIdFromSession::CreatedFreshUserId(new_user) => { let mut headers = HeaderMap::new(); - headers.insert(http::header::SET_COOKIE, cookie); - (headers, user_id) + headers.insert(http::header::SET_COOKIE, new_user.cookie); + (headers, new_user.user_id, true) } }; - dbg!(user_id); + tracing::debug!("handler: user_id={:?} send_headers={:?}", user_id, headers); - headers + ( + headers, + format!( + "user_id={:?} session_cookie_name={} create_new_session_cookie={}", + user_id, AXUM_SESSION_COOKIE_NAME, create_cookie + ), + ) +} + +struct FreshUserId { + pub user_id: UserId, + pub cookie: HeaderValue, } enum UserIdFromSession { FoundUserId(UserId), - CreatedFreshUserId { - user_id: UserId, - cookie: HeaderValue, - }, + CreatedFreshUserId(FreshUserId), } #[async_trait] @@ -79,28 +91,45 @@ where .await .expect("`MemoryStore` extension missing"); - let headers = req.headers().expect("other extractor taken headers"); + let cookie = Option::>::from_request(req) + .await + .unwrap(); - let cookie = if let Some(cookie) = headers - .get(http::header::COOKIE) - .and_then(|value| value.to_str().ok()) - .map(|value| value.to_string()) - { - cookie - } else { + let session_cookie = cookie + .as_ref() + .and_then(|cookie| cookie.get(AXUM_SESSION_COOKIE_NAME)); + + // return the new created session cookie for client + if session_cookie.is_none() { let user_id = UserId::new(); let mut session = Session::new(); session.insert("user_id", user_id).unwrap(); let cookie = store.store_session(session).await.unwrap().unwrap(); - - return Ok(Self::CreatedFreshUserId { + return Ok(Self::CreatedFreshUserId(FreshUserId { user_id, - cookie: cookie.parse().unwrap(), - }); - }; + cookie: HeaderValue::from_str( + format!("{}={}", AXUM_SESSION_COOKIE_NAME, cookie).as_str(), + ) + .unwrap(), + })); + } - let user_id = if let Some(session) = store.load_session(cookie).await.unwrap() { + tracing::debug!( + "UserIdFromSession: got session cookie from user agent, {}={}", + AXUM_SESSION_COOKIE_NAME, + session_cookie.unwrap() + ); + // continue to decode the session cookie + let user_id = if let Some(session) = store + .load_session(session_cookie.unwrap().to_owned()) + .await + .unwrap() + { if let Some(user_id) = session.get::("user_id") { + tracing::debug!( + "UserIdFromSession: session decoded success, user_id={:?}", + user_id + ); user_id } else { return Err(( @@ -109,6 +138,11 @@ where )); } } else { + tracing::debug!( + "UserIdFromSession: err session not exists in store, {}={}", + AXUM_SESSION_COOKIE_NAME, + session_cookie.unwrap() + ); return Err((StatusCode::BAD_REQUEST, "No session found for cookie")); };