Implement CSRF token verification for OAuth example (#2534)

Co-authored-by: Logan Nielsen <loganbn@amazon.com>
Co-authored-by: Logan Nielsen <logan.nielsen@one.app>
This commit is contained in:
Logan B. Nielsen 2024-11-17 15:27:54 -06:00 committed by GitHub
parent dc5c202c5f
commit 7e59625778
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,7 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ``` //! ```
use anyhow::{Context, Result}; use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore}; use async_session::{MemoryStore, Session, SessionStore};
use axum::{ use axum::{
extract::{FromRef, FromRequestParts, Query, State}, extract::{FromRef, FromRequestParts, Query, State},
@ -28,6 +28,7 @@ use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
static COOKIE_NAME: &str = "SESSION"; static COOKIE_NAME: &str = "SESSION";
static CSRF_TOKEN: &str = "csrf_token";
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -141,19 +142,37 @@ async fn index(user: Option<User>) -> impl IntoResponse {
} }
} }
async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse { async fn discord_auth(
// TODO: this example currently doesn't validate the CSRF token during login attempts. That State(client): State<BasicClient>,
// makes it vulnerable to cross-site request forgery. If you copy code from this example make State(store): State<MemoryStore>,
// sure to add a check for the CSRF token. ) -> Result<impl IntoResponse, AppError> {
// let (auth_url, csrf_token) = client
// Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511
let (auth_url, _csrf_token) = client
.authorize_url(CsrfToken::new_random) .authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("identify".to_string()))
.url(); .url();
// Redirect to Discord's oauth service // Create session to store csrf_token
Redirect::to(auth_url.as_ref()) let mut session = Session::new();
session
.insert(CSRF_TOKEN, &csrf_token)
.context("failed in inserting CSRF token into session")?;
// Store the session in MemoryStore and retrieve the session cookie
let cookie = store
.store_session(session)
.await
.context("failed to store CSRF token session")?
.context("unexpected error retrieving CSRF cookie value")?;
// Attach the session cookie to the response header
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);
Ok((headers, Redirect::to(auth_url.as_ref())))
} }
// Valid user session required. If there is none, redirect to the auth page // Valid user session required. If there is none, redirect to the auth page
@ -194,11 +213,55 @@ struct AuthRequest {
state: String, state: String,
} }
async fn csrf_token_validation_workflow(
auth_request: &AuthRequest,
cookies: &headers::Cookie,
store: &MemoryStore,
) -> Result<(), AppError> {
// Extract the cookie from the request
let cookie = cookies
.get(COOKIE_NAME)
.context("unexpected error getting cookie name")?
.to_string();
// Load the session
let session = match store
.load_session(cookie)
.await
.context("failed to load session")?
{
Some(session) => session,
None => return Err(anyhow!("Session not found").into()),
};
// Extract the CSRF token from the session
let stored_csrf_token = session
.get::<CsrfToken>(CSRF_TOKEN)
.context("CSRF token not found in session")?
.to_owned();
// Cleanup the CSRF token session
store
.destroy_session(session)
.await
.context("Failed to destroy old session")?;
// Validate CSRF token is the same as the one in the auth request
if *stored_csrf_token.secret() != auth_request.state {
return Err(anyhow!("CSRF token mismatch").into());
}
Ok(())
}
async fn login_authorized( async fn login_authorized(
Query(query): Query<AuthRequest>, Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>, State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>, State(oauth_client): State<BasicClient>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> { ) -> Result<impl IntoResponse, AppError> {
csrf_token_validation_workflow(&query, &cookies, &store).await?;
// Get an auth token // Get an auth token
let token = oauth_client let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone())) .exchange_code(AuthorizationCode::new(query.code.clone()))
@ -233,7 +296,7 @@ async fn login_authorized(
.context("unexpected error retrieving cookie value")?; .context("unexpected error retrieving cookie value")?;
// Build the cookie // Build the cookie
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/"); let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
// Set cookie // Set cookie
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();