axum/examples/jwt/src/main.rs
Jonas Platte 0a399ed0fa
A few small refactorings (#655)
* Simplify json content-type check

* Import HeaderMap from http instead of from headers

The headers crate is an optional dependency and its HeaderMap re-export
is `#[doc(hidden)]`.

* Use headers re-export in axum in examples
2021-12-27 14:02:38 +01:00

197 lines
5.4 KiB
Rust

//! Example JWT authorization/authentication.
//!
//! Run with
//!
//! ```not_rust
//! JWT_SECRET=secret cargo run -p example-jwt
//! ```
use axum::{
async_trait,
extract::{FromRequest, RequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{fmt::Display, net::SocketAddr};
// Quick instructions
//
// - get an authorization token:
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -d '{"client_id":"foo","client_secret":"bar"}' \
// http://localhost:3000/authorize
//
// - visit the protected area using the authorized token
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -H 'Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJiQGIuY29tIiwiY29tcGFueSI6IkFDTUUiLCJleHAiOjEwMDAwMDAwMDAwfQ.M3LAZmrzUkXDC1q5mSzFAs_kJrwuKz3jOoDmjJ0G4gM' \
// http://localhost:3000/protected
//
// - try to visit the protected area using an invalid token
//
// curl -s \
// -w '\n' \
// -H 'Content-Type: application/json' \
// -H 'Authorization: Bearer blahblahblah' \
// http://localhost:3000/protected
/// Process-wide JWT key pair, built on first access from the `JWT_SECRET`
/// environment variable.
///
/// Panics on first use when `JWT_SECRET` is not set (or cannot be read as a
/// valid string), since nothing can be signed or verified without it.
static KEYS: Lazy<Keys> = Lazy::new(|| {
    std::env::var("JWT_SECRET")
        .map(|raw_secret| Keys::new(raw_secret.as_bytes()))
        .expect("JWT_SECRET must be set")
});
/// Entry point: configures logging, builds the router and serves it on
/// `127.0.0.1:3000` until the server stops.
#[tokio::main]
async fn main() {
    // Fall back to a debug-level filter for this example when the caller
    // did not pick a log filter themselves.
    let log_filter_unset = std::env::var_os("RUST_LOG").is_none();
    if log_filter_unset {
        std::env::set_var("RUST_LOG", "example_jwt=debug");
    }
    tracing_subscriber::fmt::init();

    let listen_addr = SocketAddr::from(([127, 0, 0, 1], 3000));

    // `/authorize` hands out tokens, `/protected` requires one.
    let routes = Router::new()
        .route("/protected", get(protected))
        .route("/authorize", post(authorize));

    tracing::debug!("listening on {}", listen_addr);

    let server = axum::Server::bind(&listen_addr).serve(routes.into_make_service());
    server.await.unwrap();
}
/// Handler for `GET /protected`.
///
/// Only runs when the `Claims` extractor succeeded, i.e. the request carried
/// a bearer token that decoded successfully. Responds with a greeting
/// followed by the `Display` rendering of the claims.
async fn protected(claims: Claims) -> Result<String, AuthError> {
    let greeting = format!(
        "Welcome to the protected area :)\nYour data:\n{}",
        claims
    );
    Ok(greeting)
}
/// Handler for `POST /authorize`: exchanges client credentials for a signed JWT.
///
/// # Errors
///
/// - `AuthError::MissingCredentials` when `client_id` or `client_secret` is empty.
/// - `AuthError::WrongCredentials` when the pair is not the hard-coded demo
///   pair `foo` / `bar`.
/// - `AuthError::TokenCreation` when encoding the token fails.
async fn authorize(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
    // Check if the user sent the credentials
    if payload.client_id.is_empty() || payload.client_secret.is_empty() {
        return Err(AuthError::MissingCredentials);
    }
    // Here you can check the user credentials from a database
    if payload.client_id != "foo" || payload.client_secret != "bar" {
        return Err(AuthError::WrongCredentials);
    }
    let claims = Claims {
        sub: "b@b.com".to_owned(),
        company: "ACME".to_owned(),
        // Mandatory expiry date as a Unix timestamp (seconds). It has to lie
        // in the future: the `Claims` extractor decodes with
        // `Validation::default()`, which rejects tokens whose `exp` has
        // passed. The previous value (100000, i.e. January 1970) made every
        // issued token unusable. This value matches the sample token shown in
        // the quick instructions at the top of the file.
        exp: 10_000_000_000,
    };
    // Create the authorization token
    let token = encode(&Header::default(), &claims, &KEYS.encoding)
        .map_err(|_| AuthError::TokenCreation)?;
    // Send the authorized token
    Ok(Json(AuthBody::new(token)))
}
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
}
}
impl AuthBody {
    /// Wraps a freshly issued token in a response body whose `token_type`
    /// is always `"Bearer"`.
    fn new(access_token: String) -> Self {
        let token_type = String::from("Bearer");
        Self {
            access_token,
            token_type,
        }
    }
}
/// Extractor that turns the request's `Authorization: Bearer <token>` header
/// into decoded `Claims`.
///
/// Both a missing/unparsable header and a token that fails to decode are
/// rejected with `AuthError::InvalidToken`.
#[async_trait]
impl<B> FromRequest<B> for Claims
where
    B: Send,
{
    type Rejection = AuthError;

    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        // Pull the bearer credentials out of the typed authorization header.
        let credentials = match TypedHeader::<Authorization<Bearer>>::from_request(req).await {
            Ok(TypedHeader(Authorization(bearer))) => bearer,
            Err(_) => return Err(AuthError::InvalidToken),
        };

        // Decode the token with the shared key and default validation rules,
        // keeping only the claims.
        let validation = Validation::default();
        decode::<Claims>(credentials.token(), &KEYS.decoding, &validation)
            .map(|decoded| decoded.claims)
            .map_err(|_| AuthError::InvalidToken)
    }
}
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}
/// Pair of JWT keys built from one shared secret (see `Keys::new`).
#[derive(Debug)]
struct Keys {
/// Key used with `encode` when issuing tokens.
encoding: EncodingKey,
/// Key used with `decode` when verifying tokens; owned (`'static`) so it
/// can live in the global `KEYS` static.
decoding: DecodingKey<'static>,
}
impl Keys {
    /// Builds the encoding and decoding keys from the same secret bytes.
    ///
    /// The decoding key is converted to an owned (`'static`) key so the
    /// result does not borrow from `secret`.
    fn new(secret: &[u8]) -> Self {
        let encoding = EncodingKey::from_secret(secret);
        let decoding = DecodingKey::from_secret(secret).into_static();
        Self { encoding, decoding }
    }
}
/// Payload carried inside the JWT; serialized when a token is issued and
/// deserialized when a token is decoded.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
/// Subject of the token; the `Display` impl renders it as the e-mail line.
sub: String,
/// Company associated with the subject.
company: String,
/// Expiry of the token; presumably a Unix timestamp in seconds, per the
/// JWT `exp` claim convention — confirm against the jsonwebtoken docs.
exp: usize,
}
/// JSON response body returned by the `authorize` handler on success.
#[derive(Debug, Serialize)]
struct AuthBody {
/// The encoded JWT.
access_token: String,
/// Token scheme; `AuthBody::new` always sets this to `"Bearer"`.
token_type: String,
}
/// JSON request body expected by the `authorize` handler.
#[derive(Debug, Deserialize)]
struct AuthPayload {
/// Client identifier; must be non-empty.
client_id: String,
/// Client secret; must be non-empty.
client_secret: String,
}
/// Failures that the handlers and the `Claims` extractor can report; each is
/// converted to an HTTP response by the `IntoResponse` impl.
#[derive(Debug)]
enum AuthError {
/// Credentials were supplied but did not match (401 Unauthorized).
WrongCredentials,
/// `client_id` or `client_secret` was empty (400 Bad Request).
MissingCredentials,
/// Encoding the token failed (500 Internal Server Error).
TokenCreation,
/// The bearer header was missing/unparsable or the token failed to decode
/// (400 Bad Request).
InvalidToken,
}