Implement CSRF token verification for OAuth example

This commit is contained in:
Logan Nielsen 2024-01-24 22:25:18 -08:00
parent d1fb14ead1
commit 87b9b9115e

View file

@ -8,7 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ``` //! ```
use anyhow::{Context, Result}; use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore}; use async_session::{MemoryStore, Session, SessionStore};
use axum::{ use axum::{
async_trait, async_trait,
@ -142,19 +142,37 @@ async fn index(user: Option<User>) -> impl IntoResponse {
} }
} }
async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse { async fn discord_auth(
// TODO: this example currently doesn't validate the CSRF token during login attempts. That State(client): State<BasicClient>,
// makes it vulnerable to cross-site request forgery. If you copy code from this example make State(store): State<MemoryStore>,
// sure to add a check for the CSRF token. ) -> Result<impl IntoResponse, AppError> {
// let (auth_url, csrf_token) = client
// Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511
let (auth_url, _csrf_token) = client
.authorize_url(CsrfToken::new_random) .authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("identify".to_string()))
.url(); .url();
// Redirect to Discord's oauth service // Create session to store csrf_token
Redirect::to(auth_url.as_ref()) let mut session = Session::new();
session
.insert("csrf_token", &csrf_token)
.context("failed in inserting CSRF token into session")?;
// Store the session in MemoryStore and retrieve the session cookie
let cookie = store
.store_session(session)
.await
.context("failed to store CSRF token session")?
.context("unexpected error retrieving CSRF cookie value")?;
// Attach the session cookie to the response header
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);
Ok((headers, Redirect::to(auth_url.as_ref())))
} }
// Valid user session required. If there is none, redirect to the auth page // Valid user session required. If there is none, redirect to the auth page
@ -195,11 +213,55 @@ struct AuthRequest {
state: String, state: String,
} }
async fn csrf_token_validation_workflow(
auth_request: &AuthRequest,
cookies: &headers::Cookie,
store: &MemoryStore,
) -> Result<(), AppError> {
// Extract the cookie from the request
let cookie = cookies
.get(COOKIE_NAME)
.context("unexpected error getting cookie name")?
.to_string();
// Load the session
let session = match store
.load_session(cookie)
.await
.context("failed to load session")?
{
Some(session) => session,
None => return Err(anyhow!("Session not found").into()),
};
// Extract the CSRF token from the session
let stored_csrf_token = session
.get::<CsrfToken>("csrf_token")
.context("CSRF token not found in session")?
.to_owned();
// Cleanup the CSRF token session
store
.destroy_session(session)
.await
.context("Failed to destroy old session")?;
// Validate CSRF token is the same as the one in the auth request
if *stored_csrf_token.secret() != auth_request.state {
return Err(anyhow!("CSRF token mismatch").into());
}
Ok(())
}
async fn login_authorized( async fn login_authorized(
Query(query): Query<AuthRequest>, Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>, State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>, State(oauth_client): State<BasicClient>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> { ) -> Result<impl IntoResponse, AppError> {
csrf_token_validation_workflow(&query, &cookies, &store).await?;
// Get an auth token // Get an auth token
let token = oauth_client let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone())) .exchange_code(AuthorizationCode::new(query.code.clone()))