Rework error handling example (#2382)

This commit is contained in:
David Pedersen 2023-12-29 01:49:13 +01:00 committed by GitHub
parent 6c276c3ff0
commit c3db223532
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 235 additions and 160 deletions

View file

@ -1,155 +0,0 @@
//! Example showing how to convert errors into responses and how one might do
//! dependency injection using trait objects.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-error-handling-and-dependency-injection
//! ```
use axum::{
async_trait,
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use uuid::Uuid;
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_error_handling_and_dependency_injection=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// Inject a `UserRepo` into our handlers via a trait object. This could be
// the live implementation or just a mock for testing.
let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo;
// Build our application with some routes
let app = Router::new()
.route("/users/:id", get(users_show))
.route("/users", post(users_create))
.with_state(user_repo);
// Run our application
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
/// Handler for `GET /users/:id`.
///
/// Extracts the user repo from request extensions and calls it. `UserRepoError`s
/// are automatically converted into `AppError` which implements `IntoResponse`
/// so it can be returned from handlers directly.
async fn users_show(
Path(user_id): Path<Uuid>,
State(user_repo): State<DynUserRepo>,
) -> Result<Json<User>, AppError> {
let user = user_repo.find(user_id).await?;
Ok(user.into())
}
/// Handler for `POST /users`.
async fn users_create(
State(user_repo): State<DynUserRepo>,
Json(params): Json<CreateUser>,
) -> Result<Json<User>, AppError> {
let user = user_repo.create(params).await?;
Ok(user.into())
}
/// Our app's top level error type.
enum AppError {
/// Something went wrong when calling the user repo.
UserRepo(UserRepoError),
}
/// This makes it possible to use `?` to automatically convert a `UserRepoError`
/// into an `AppError`.
impl From<UserRepoError> for AppError {
fn from(inner: UserRepoError) -> Self {
AppError::UserRepo(inner)
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AppError::UserRepo(UserRepoError::NotFound) => {
(StatusCode::NOT_FOUND, "User not found")
}
AppError::UserRepo(UserRepoError::InvalidUsername) => {
(StatusCode::UNPROCESSABLE_ENTITY, "Invalid username")
}
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}
/// Example implementation of `UserRepo`.
struct ExampleUserRepo;
#[async_trait]
impl UserRepo for ExampleUserRepo {
async fn find(&self, _user_id: Uuid) -> Result<User, UserRepoError> {
unimplemented!()
}
async fn create(&self, _params: CreateUser) -> Result<User, UserRepoError> {
unimplemented!()
}
}
/// Type alias that makes it easier to extract `UserRepo` trait objects.
type DynUserRepo = Arc<dyn UserRepo + Send + Sync>;
/// A trait that defines things a user repo might support.
#[async_trait]
trait UserRepo {
/// Loop up a user by their id.
async fn find(&self, user_id: Uuid) -> Result<User, UserRepoError>;
/// Create a new user.
async fn create(&self, params: CreateUser) -> Result<User, UserRepoError>;
}
#[derive(Debug, Serialize)]
struct User {
id: Uuid,
username: String,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct CreateUser {
username: String,
}
/// Errors that can happen when using the user repo.
#[derive(Debug)]
enum UserRepoError {
#[allow(dead_code)]
NotFound,
#[allow(dead_code)]
InvalidUsername,
}

View file

@ -1,15 +1,13 @@
[package]
name = "example-error-handling-and-dependency-injection"
name = "example-error-handling"
version = "0.1.0"
edition = "2021"
publish = false
[dependencies]
axum = { path = "../../axum" }
axum = { path = "../../axum", features = ["macros"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.5", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.0", features = ["v4", "serde"] }

View file

@ -0,0 +1,232 @@
//! Example showing how to convert errors into responses.
//!
//! Run with
//!
//! ```not_rust
//! cargo run -p example-error-handling
//! ```
//!
//! For successful requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200
//! ```
//!
//! For failed requests the log output will be
//!
//! ```ignore
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500
//! ```
use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
},
};
use axum::{
extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Router,
};
use serde::{Deserialize, Serialize};
use time_library::Timestamp;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let state = AppState::default();
let app = Router::new()
// A dummy route that accepts some JSON but sometimes fails
.route("/users", post(users_create))
.layer(
TraceLayer::new_for_http()
// Create our own span for the request and include the matched path. The matched
// path is useful for figuring out which handler the request was routed to.
.make_span_with(|req: &Request| {
let method = req.method();
let uri = req.uri();
// axum automatically adds this extension.
let matched_path = req
.extensions()
.get::<MatchedPath>()
.map(|matched_path| matched_path.as_str());
tracing::debug_span!("request", %method, %uri, matched_path)
})
// By default `TraceLayer` will log 5xx responses but we're doing our specific
// logging of errors so disable that
.on_failure(()),
)
.with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
#[derive(Default, Clone)]
struct AppState {
next_id: Arc<AtomicU64>,
users: Arc<Mutex<HashMap<u64, User>>>,
}
#[derive(Deserialize)]
struct UserParams {
name: String,
}
#[derive(Serialize, Clone)]
struct User {
id: u64,
name: String,
created_at: Timestamp,
}
async fn users_create(
State(state): State<AppState>,
// Make sure to use our own JSON extractor so we get input errors formatted in a way that
// matches our application
AppJson(params): AppJson<UserParams>,
) -> Result<AppJson<User>, AppError> {
let id = state.next_id.fetch_add(1, Ordering::SeqCst);
// We have implemented `From<time_library::Error> for AppError` which allows us to use `?` to
// automatically convert the error
let created_at = Timestamp::now()?;
let user = User {
id,
name: params.name,
created_at,
};
state.users.lock().unwrap().insert(id, user.clone());
Ok(AppJson(user))
}
// Create our own JSON extractor by wrapping `axum::Json`. This makes it easy to override the
// rejection and provide our own which formats errors to match our application.
//
// `axum::Json` responds with plain text if the input is invalid.
#[derive(FromRequest)]
#[from_request(via(axum::Json), rejection(AppError))]
struct AppJson<T>(T);
impl<T> IntoResponse for AppJson<T>
where
axum::Json<T>: IntoResponse,
{
fn into_response(self) -> Response {
axum::Json(self.0).into_response()
}
}
// The kinds of errors we can hit in our application.
enum AppError {
// The request body contained invalid JSON
JsonRejection(JsonRejection),
// Some error from a third party library we're using
TimeError(time_library::Error),
}
// Tell axum how `AppError` should be converted into a response.
//
// This is also a convenient place to log errors.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
// How we want errors responses to be serialized
#[derive(Serialize)]
struct ErrorResponse {
message: String,
}
let (status, message) = match self {
AppError::JsonRejection(rejection) => {
// This error is caused by bad user input so don't log it
(rejection.status(), rejection.body_text())
}
AppError::TimeError(err) => {
// Because `TraceLayer` wraps each request in a span that contains the request
// method, uri, etc we don't need to include those details here
tracing::error!(%err, "error from time_library");
// Don't expose any details about the error to the client
(
StatusCode::INTERNAL_SERVER_ERROR,
"Something went wrong".to_owned(),
)
}
};
(status, AppJson(ErrorResponse { message })).into_response()
}
}
impl From<JsonRejection> for AppError {
fn from(rejection: JsonRejection) -> Self {
Self::JsonRejection(rejection)
}
}
impl From<time_library::Error> for AppError {
fn from(error: time_library::Error) -> Self {
Self::TimeError(error)
}
}
// Imagine this is some third party library that we're using. It sometimes returns errors which we
// want to log.
mod time_library {
use std::sync::atomic::{AtomicU64, Ordering};
use serde::Serialize;
#[derive(Serialize, Clone)]
pub struct Timestamp(u64);
impl Timestamp {
pub fn now() -> Result<Self, Error> {
static COUNTER: AtomicU64 = AtomicU64::new(0);
// Fail on every third call just to simulate errors
if COUNTER.fetch_add(1, Ordering::SeqCst) % 3 == 0 {
Err(Error::FailedToGetTime)
} else {
Ok(Self(1337))
}
}
}
#[derive(Debug)]
pub enum Error {
FailedToGetTime,
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "failed to get time")
}
}
}