mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-22 07:08:16 +01:00
Fix session cookie example (#638)
* refactor: refine session cookie example * refactor: refine session_cookie extraction * refactor: avoid to_owned() * chore: refine debug log Co-authored-by: 荒野無燈 <ttys3.rust@gmail.com>
This commit is contained in:
parent
4c48efc861
commit
3841ef44d5
2 changed files with 61 additions and 27 deletions
|
@ -5,7 +5,7 @@ edition = "2018"
|
|||
publish = false
|
||||
|
||||
[dependencies]
|
||||
axum = { path = "../../axum" }
|
||||
axum = { path = "../../axum", features = ["headers"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
use async_session::{MemoryStore, Session, SessionStore as _};
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{Extension, FromRequest, RequestParts},
|
||||
extract::{Extension, FromRequest, RequestParts, TypedHeader},
|
||||
headers::Cookie,
|
||||
http::{
|
||||
self,
|
||||
header::{HeaderMap, HeaderValue},
|
||||
|
@ -18,9 +19,12 @@ use axum::{
|
|||
AddExtensionLayer, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use std::net::SocketAddr;
|
||||
use uuid::Uuid;
|
||||
|
||||
const AXUM_SESSION_COOKIE_NAME: &str = "axum_session";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Set the RUST_LOG, if it hasn't been explicitly defined
|
||||
|
@ -45,26 +49,34 @@ async fn main() {
|
|||
}
|
||||
|
||||
async fn handler(user_id: UserIdFromSession) -> impl IntoResponse {
|
||||
let (headers, user_id) = match user_id {
|
||||
UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id),
|
||||
UserIdFromSession::CreatedFreshUserId { user_id, cookie } => {
|
||||
let (headers, user_id, create_cookie) = match user_id {
|
||||
UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id, false),
|
||||
UserIdFromSession::CreatedFreshUserId(new_user) => {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(http::header::SET_COOKIE, cookie);
|
||||
(headers, user_id)
|
||||
headers.insert(http::header::SET_COOKIE, new_user.cookie);
|
||||
(headers, new_user.user_id, true)
|
||||
}
|
||||
};
|
||||
|
||||
dbg!(user_id);
|
||||
tracing::debug!("handler: user_id={:?} send_headers={:?}", user_id, headers);
|
||||
|
||||
headers
|
||||
(
|
||||
headers,
|
||||
format!(
|
||||
"user_id={:?} session_cookie_name={} create_new_session_cookie={}",
|
||||
user_id, AXUM_SESSION_COOKIE_NAME, create_cookie
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
struct FreshUserId {
|
||||
pub user_id: UserId,
|
||||
pub cookie: HeaderValue,
|
||||
}
|
||||
|
||||
enum UserIdFromSession {
|
||||
FoundUserId(UserId),
|
||||
CreatedFreshUserId {
|
||||
user_id: UserId,
|
||||
cookie: HeaderValue,
|
||||
},
|
||||
CreatedFreshUserId(FreshUserId),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -79,28 +91,45 @@ where
|
|||
.await
|
||||
.expect("`MemoryStore` extension missing");
|
||||
|
||||
let headers = req.headers().expect("other extractor taken headers");
|
||||
let cookie = Option::<TypedHeader<Cookie>>::from_request(req)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cookie = if let Some(cookie) = headers
|
||||
.get(http::header::COOKIE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|value| value.to_string())
|
||||
{
|
||||
cookie
|
||||
} else {
|
||||
let session_cookie = cookie
|
||||
.as_ref()
|
||||
.and_then(|cookie| cookie.get(AXUM_SESSION_COOKIE_NAME));
|
||||
|
||||
// return the new created session cookie for client
|
||||
if session_cookie.is_none() {
|
||||
let user_id = UserId::new();
|
||||
let mut session = Session::new();
|
||||
session.insert("user_id", user_id).unwrap();
|
||||
let cookie = store.store_session(session).await.unwrap().unwrap();
|
||||
|
||||
return Ok(Self::CreatedFreshUserId {
|
||||
return Ok(Self::CreatedFreshUserId(FreshUserId {
|
||||
user_id,
|
||||
cookie: cookie.parse().unwrap(),
|
||||
});
|
||||
};
|
||||
cookie: HeaderValue::from_str(
|
||||
format!("{}={}", AXUM_SESSION_COOKIE_NAME, cookie).as_str(),
|
||||
)
|
||||
.unwrap(),
|
||||
}));
|
||||
}
|
||||
|
||||
let user_id = if let Some(session) = store.load_session(cookie).await.unwrap() {
|
||||
tracing::debug!(
|
||||
"UserIdFromSession: got session cookie from user agent, {}={}",
|
||||
AXUM_SESSION_COOKIE_NAME,
|
||||
session_cookie.unwrap()
|
||||
);
|
||||
// continue to decode the session cookie
|
||||
let user_id = if let Some(session) = store
|
||||
.load_session(session_cookie.unwrap().to_owned())
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
if let Some(user_id) = session.get::<UserId>("user_id") {
|
||||
tracing::debug!(
|
||||
"UserIdFromSession: session decoded success, user_id={:?}",
|
||||
user_id
|
||||
);
|
||||
user_id
|
||||
} else {
|
||||
return Err((
|
||||
|
@ -109,6 +138,11 @@ where
|
|||
));
|
||||
}
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"UserIdFromSession: err session not exists in store, {}={}",
|
||||
AXUM_SESSION_COOKIE_NAME,
|
||||
session_cookie.unwrap()
|
||||
);
|
||||
return Err((StatusCode::BAD_REQUEST, "No session found for cookie"));
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in a new issue