mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 06:36:32 +01:00
Implement CSRF token verification for OAuth example (#2534)
Co-authored-by: Logan Nielsen <loganbn@amazon.com> Co-authored-by: Logan Nielsen <logan.nielsen@one.app>
This commit is contained in:
parent
dc5c202c5f
commit
7e59625778
1 changed files with 74 additions and 11 deletions
|
@ -8,7 +8,7 @@
|
|||
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
|
||||
//! ```
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_session::{MemoryStore, Session, SessionStore};
|
||||
use axum::{
|
||||
extract::{FromRef, FromRequestParts, Query, State},
|
||||
|
@ -28,6 +28,7 @@ use std::env;
|
|||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
static COOKIE_NAME: &str = "SESSION";
|
||||
static CSRF_TOKEN: &str = "csrf_token";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -141,19 +142,37 @@ async fn index(user: Option<User>) -> impl IntoResponse {
|
|||
}
|
||||
}
|
||||
|
||||
async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
|
||||
// TODO: this example currently doesn't validate the CSRF token during login attempts. That
|
||||
// makes it vulnerable to cross-site request forgery. If you copy code from this example make
|
||||
// sure to add a check for the CSRF token.
|
||||
//
|
||||
// Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511
|
||||
let (auth_url, _csrf_token) = client
|
||||
async fn discord_auth(
|
||||
State(client): State<BasicClient>,
|
||||
State(store): State<MemoryStore>,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
let (auth_url, csrf_token) = client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("identify".to_string()))
|
||||
.url();
|
||||
|
||||
// Redirect to Discord's oauth service
|
||||
Redirect::to(auth_url.as_ref())
|
||||
// Create session to store csrf_token
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.insert(CSRF_TOKEN, &csrf_token)
|
||||
.context("failed in inserting CSRF token into session")?;
|
||||
|
||||
// Store the session in MemoryStore and retrieve the session cookie
|
||||
let cookie = store
|
||||
.store_session(session)
|
||||
.await
|
||||
.context("failed to store CSRF token session")?
|
||||
.context("unexpected error retrieving CSRF cookie value")?;
|
||||
|
||||
// Attach the session cookie to the response header
|
||||
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
SET_COOKIE,
|
||||
cookie.parse().context("failed to parse cookie")?,
|
||||
);
|
||||
|
||||
Ok((headers, Redirect::to(auth_url.as_ref())))
|
||||
}
|
||||
|
||||
// Valid user session required. If there is none, redirect to the auth page
|
||||
|
@ -194,11 +213,55 @@ struct AuthRequest {
|
|||
state: String,
|
||||
}
|
||||
|
||||
async fn csrf_token_validation_workflow(
|
||||
auth_request: &AuthRequest,
|
||||
cookies: &headers::Cookie,
|
||||
store: &MemoryStore,
|
||||
) -> Result<(), AppError> {
|
||||
// Extract the cookie from the request
|
||||
let cookie = cookies
|
||||
.get(COOKIE_NAME)
|
||||
.context("unexpected error getting cookie name")?
|
||||
.to_string();
|
||||
|
||||
// Load the session
|
||||
let session = match store
|
||||
.load_session(cookie)
|
||||
.await
|
||||
.context("failed to load session")?
|
||||
{
|
||||
Some(session) => session,
|
||||
None => return Err(anyhow!("Session not found").into()),
|
||||
};
|
||||
|
||||
// Extract the CSRF token from the session
|
||||
let stored_csrf_token = session
|
||||
.get::<CsrfToken>(CSRF_TOKEN)
|
||||
.context("CSRF token not found in session")?
|
||||
.to_owned();
|
||||
|
||||
// Cleanup the CSRF token session
|
||||
store
|
||||
.destroy_session(session)
|
||||
.await
|
||||
.context("Failed to destroy old session")?;
|
||||
|
||||
// Validate CSRF token is the same as the one in the auth request
|
||||
if *stored_csrf_token.secret() != auth_request.state {
|
||||
return Err(anyhow!("CSRF token mismatch").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn login_authorized(
|
||||
Query(query): Query<AuthRequest>,
|
||||
State(store): State<MemoryStore>,
|
||||
State(oauth_client): State<BasicClient>,
|
||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||
) -> Result<impl IntoResponse, AppError> {
|
||||
csrf_token_validation_workflow(&query, &cookies, &store).await?;
|
||||
|
||||
// Get an auth token
|
||||
let token = oauth_client
|
||||
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||
|
@ -233,7 +296,7 @@ async fn login_authorized(
|
|||
.context("unexpected error retrieving cookie value")?;
|
||||
|
||||
// Build the cookie
|
||||
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/");
|
||||
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
|
||||
|
||||
// Set cookie
|
||||
let mut headers = HeaderMap::new();
|
||||
|
|
Loading…
Reference in a new issue