From 6c13b22cd42d84cfe3aee8370eef635a4c4e23a1 Mon Sep 17 00:00:00 2001 From: silvioprog Date: Fri, 10 Sep 2021 12:51:20 -0300 Subject: [PATCH] Add JWT example (#306) (#308) Co-authored-by: Jonas Platte --- examples/jwt/Cargo.toml | 16 ++++ examples/jwt/src/main.rs | 201 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 examples/jwt/Cargo.toml create mode 100644 examples/jwt/src/main.rs diff --git a/examples/jwt/Cargo.toml b/examples/jwt/Cargo.toml new file mode 100644 index 00000000..201ead4b --- /dev/null +++ b/examples/jwt/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "example-jwt" +version = "0.1.0" +edition = "2018" +publish = false + +[dependencies] +axum = { path = "../..", features = ["headers"] } +tokio = { version = "1.0", features = ["full"] } +tracing = "0.1" +tracing-subscriber = "0.2" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +headers = "0.3" +jsonwebtoken = "7" +once_cell = "1.8" diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs new file mode 100644 index 00000000..5c223133 --- /dev/null +++ b/examples/jwt/src/main.rs @@ -0,0 +1,201 @@ +//! Example JWT authorization/authentication. +//! +//! Run with +//! +//! ```not_rust +//! JWT_SECRET=secret cargo run -p example-jwt +//! ``` + +use axum::{ + async_trait, + body::{Bytes, Full}, + extract::{FromRequest, RequestParts, TypedHeader}, + handler::{get, post}, + http::{Response, StatusCode}, + response::IntoResponse, + Json, Router, +}; +use headers::{authorization::Bearer, Authorization}; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::{convert::Infallible, fmt::Display, net::SocketAddr}; + +// Quick instructions +// +// - get an authorization token: +// +// curl -s \ +// -w '\n' \ +// -H 'Content-Type: application/json' \ +// -d '{"client_id":"foo","client_secret":"bar"}' \ +// http://localhost:3000/authorize +// +// - visit the protected area using the authorized token +// +// curl -s \ +// -w '\n' \ +// -H 'Content-Type: application/json' \ +// -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \ +// http://localhost:3000/protected +// +// - try to visit the protected area using an invalid token +// +// curl -s \ +// -w '\n' \ +// -H 'Content-Type: application/json' \ +// -H 'Authorization: Bearer blahblahblah' \ +// http://localhost:3000/protected + +static KEYS: Lazy = Lazy::new(|| { + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + Keys::new(secret.as_bytes()) +}); + +#[tokio::main] +async fn main() { + // Set the RUST_LOG, if it hasn't been explicitly defined + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "example_jwt=debug") + } + tracing_subscriber::fmt::init(); + + let app = Router::new() + .route("/protected", get(protected)) + .route("/authorize", post(authorize)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn protected(claims: Claims) -> Result { + // Send the protected data to the user + Ok(format!( + "Welcome to the protected area :)\nYour data:\n{}", + claims + )) +} + +async fn authorize(Json(payload): Json) -> Result, AuthError> { + // Check if the user sent the credentials + if payload.client_id.is_empty() || payload.client_secret.is_empty() { + return Err(AuthError::MissingCredentials); + } + // Here you can check the user credentials from a database + if payload.client_id != "foo" || payload.client_secret != "bar" { + return Err(AuthError::WrongCredentials); + } + let claims = Claims { + sub: "b@b.com".to_owned(), + company: "ACME".to_owned(), + exp: 10000000000, + }; + // Create the authorization token + let token = encode(&Header::default(), &claims, &KEYS.encoding) + .map_err(|_| AuthError::TokenCreation)?; + + // Send the authorized token + Ok(Json(AuthBody::new(token))) +} + +impl Display for Claims { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Email: {}\nCompany: {}", self.sub, self.company) + } +} + +impl AuthBody { + fn new(access_token: String) -> Self { + Self { + access_token, + token_type: "Bearer".to_string(), + } + } +} + +#[async_trait] +impl FromRequest for Claims +where + B: Send, +{ + type Rejection = AuthError; + + async fn from_request(req: &mut RequestParts) -> Result { + // Extract the token from the authorization header + let TypedHeader(Authorization(bearer)) = + TypedHeader::>::from_request(req) + .await + .map_err(|_| AuthError::InvalidToken)?; + // Decode the user data + let token_data = decode::(bearer.token(), &KEYS.decoding, &Validation::default()) + .map_err(|_| AuthError::InvalidToken)?; + + Ok(token_data.claims) + } +} + +impl IntoResponse for AuthError { + type Body = Full; + type BodyError = Infallible; + + fn into_response(self) -> Response { + let (status, error_message) = match self { + AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"), + AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"), + AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"), + AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"), + }; + let body = Json(json!({ + "error": error_message, + })); + (status, body).into_response() + } +} + +#[derive(Debug)] +struct Keys { + encoding: EncodingKey, + decoding: DecodingKey<'static>, +} + +impl Keys { + fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret).into_static(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + sub: String, + company: String, + exp: usize, +} + +#[derive(Debug, Serialize)] +struct AuthBody { + access_token: String, + token_type: String, +} + +#[derive(Debug, Deserialize)] +struct AuthPayload { + client_id: String, + client_secret: String, +} + +#[derive(Debug)] +enum AuthError { + WrongCredentials, + MissingCredentials, + TokenCreation, + InvalidToken, +}