From c3db223532f1f1006d2e0b0c576d627d6e95cdcb Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 29 Dec 2023 01:49:13 +0100 Subject: [PATCH] Rework error handling example (#2382) --- .../src/main.rs | 155 ------------ .../Cargo.toml | 8 +- examples/error-handling/src/main.rs | 232 ++++++++++++++++++ 3 files changed, 235 insertions(+), 160 deletions(-) delete mode 100644 examples/error-handling-and-dependency-injection/src/main.rs rename examples/{error-handling-and-dependency-injection => error-handling}/Cargo.toml (55%) create mode 100644 examples/error-handling/src/main.rs diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs deleted file mode 100644 index a5bdad9e..00000000 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ /dev/null @@ -1,155 +0,0 @@ -//! Example showing how to convert errors into responses and how one might do -//! dependency injection using trait objects. -//! -//! Run with -//! -//! ```not_rust -//! cargo run -p example-error-handling-and-dependency-injection -//! ``` - -use axum::{ - async_trait, - extract::{Path, State}, - http::StatusCode, - response::{IntoResponse, Response}, - routing::{get, post}, - Json, Router, -}; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::sync::Arc; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -use uuid::Uuid; - -#[tokio::main] -async fn main() { - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_error_handling_and_dependency_injection=debug".into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - - // Inject a `UserRepo` into our handlers via a trait object. This could be - // the live implementation or just a mock for testing. - let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; - - // Build our application with some routes - let app = Router::new() - .route("/users/:id", get(users_show)) - .route("/users", post(users_create)) - .with_state(user_repo); - - // Run our application - let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") - .await - .unwrap(); - tracing::debug!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); -} - -/// Handler for `GET /users/:id`. -/// -/// Extracts the user repo from request extensions and calls it. `UserRepoError`s -/// are automatically converted into `AppError` which implements `IntoResponse` -/// so it can be returned from handlers directly. -async fn users_show( - Path(user_id): Path, - State(user_repo): State, -) -> Result, AppError> { - let user = user_repo.find(user_id).await?; - - Ok(user.into()) -} - -/// Handler for `POST /users`. -async fn users_create( - State(user_repo): State, - Json(params): Json, -) -> Result, AppError> { - let user = user_repo.create(params).await?; - - Ok(user.into()) -} - -/// Our app's top level error type. -enum AppError { - /// Something went wrong when calling the user repo. - UserRepo(UserRepoError), -} - -/// This makes it possible to use `?` to automatically convert a `UserRepoError` -/// into an `AppError`. -impl From for AppError { - fn from(inner: UserRepoError) -> Self { - AppError::UserRepo(inner) - } -} - -impl IntoResponse for AppError { - fn into_response(self) -> Response { - let (status, error_message) = match self { - AppError::UserRepo(UserRepoError::NotFound) => { - (StatusCode::NOT_FOUND, "User not found") - } - AppError::UserRepo(UserRepoError::InvalidUsername) => { - (StatusCode::UNPROCESSABLE_ENTITY, "Invalid username") - } - }; - - let body = Json(json!({ - "error": error_message, - })); - - (status, body).into_response() - } -} - -/// Example implementation of `UserRepo`. -struct ExampleUserRepo; - -#[async_trait] -impl UserRepo for ExampleUserRepo { - async fn find(&self, _user_id: Uuid) -> Result { - unimplemented!() - } - - async fn create(&self, _params: CreateUser) -> Result { - unimplemented!() - } -} - -/// Type alias that makes it easier to extract `UserRepo` trait objects. -type DynUserRepo = Arc; - -/// A trait that defines things a user repo might support. -#[async_trait] -trait UserRepo { - /// Loop up a user by their id. - async fn find(&self, user_id: Uuid) -> Result; - - /// Create a new user. - async fn create(&self, params: CreateUser) -> Result; -} - -#[derive(Debug, Serialize)] -struct User { - id: Uuid, - username: String, -} - -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct CreateUser { - username: String, -} - -/// Errors that can happen when using the user repo. -#[derive(Debug)] -enum UserRepoError { - #[allow(dead_code)] - NotFound, - #[allow(dead_code)] - InvalidUsername, -} diff --git a/examples/error-handling-and-dependency-injection/Cargo.toml b/examples/error-handling/Cargo.toml similarity index 55% rename from examples/error-handling-and-dependency-injection/Cargo.toml rename to examples/error-handling/Cargo.toml index 583ab15a..26fc3b98 100644 --- a/examples/error-handling-and-dependency-injection/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -1,15 +1,13 @@ [package] -name = "example-error-handling-and-dependency-injection" +name = "example-error-handling" version = "0.1.0" edition = "2021" publish = false [dependencies] -axum = { path = "../../axum" } +axum = { path = "../../axum", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower-http = { version = "0.5", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -uuid = { version = "1.0", features = ["v4", "serde"] } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs new file mode 100644 index 00000000..6981f59e --- /dev/null +++ b/examples/error-handling/src/main.rs @@ -0,0 +1,232 @@ +//! Example showing how to convert errors into responses. +//! +//! Run with +//! +//! ```not_rust +//! cargo run -p example-error-handling +//! ``` +//! +//! For successful requests the log output will be +//! +//! ```ignore +//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request +//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200 +//! ``` +//! +//! For failed requests the log output will be +//! +//! ```ignore +//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request +//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time +//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500 +//! ``` + +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, +}; + +use axum::{ + extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::post, + Router, +}; +use serde::{Deserialize, Serialize}; +use time_library::Timestamp; +use tower_http::trace::TraceLayer; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let state = AppState::default(); + + let app = Router::new() + // A dummy route that accepts some JSON but sometimes fails + .route("/users", post(users_create)) + .layer( + TraceLayer::new_for_http() + // Create our own span for the request and include the matched path. The matched + // path is useful for figuring out which handler the request was routed to. + .make_span_with(|req: &Request| { + let method = req.method(); + let uri = req.uri(); + + // axum automatically adds this extension. + let matched_path = req + .extensions() + .get::() + .map(|matched_path| matched_path.as_str()); + + tracing::debug_span!("request", %method, %uri, matched_path) + }) + // By default `TraceLayer` will log 5xx responses but we're doing our specific + // logging of errors so disable that + .on_failure(()), + ) + .with_state(state); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +#[derive(Default, Clone)] +struct AppState { + next_id: Arc, + users: Arc>>, +} + +#[derive(Deserialize)] +struct UserParams { + name: String, +} + +#[derive(Serialize, Clone)] +struct User { + id: u64, + name: String, + created_at: Timestamp, +} + +async fn users_create( + State(state): State, + // Make sure to use our own JSON extractor so we get input errors formatted in a way that + // matches our application + AppJson(params): AppJson, +) -> Result, AppError> { + let id = state.next_id.fetch_add(1, Ordering::SeqCst); + + // We have implemented `From for AppError` which allows us to use `?` to + // automatically convert the error + let created_at = Timestamp::now()?; + + let user = User { + id, + name: params.name, + created_at, + }; + + state.users.lock().unwrap().insert(id, user.clone()); + + Ok(AppJson(user)) +} + +// Create our own JSON extractor by wrapping `axum::Json`. This makes it easy to override the +// rejection and provide our own which formats errors to match our application. +// +// `axum::Json` responds with plain text if the input is invalid. +#[derive(FromRequest)] +#[from_request(via(axum::Json), rejection(AppError))] +struct AppJson(T); + +impl IntoResponse for AppJson +where + axum::Json: IntoResponse, +{ + fn into_response(self) -> Response { + axum::Json(self.0).into_response() + } +} + +// The kinds of errors we can hit in our application. +enum AppError { + // The request body contained invalid JSON + JsonRejection(JsonRejection), + // Some error from a third party library we're using + TimeError(time_library::Error), +} + +// Tell axum how `AppError` should be converted into a response. +// +// This is also a convenient place to log errors. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + // How we want errors responses to be serialized + #[derive(Serialize)] + struct ErrorResponse { + message: String, + } + + let (status, message) = match self { + AppError::JsonRejection(rejection) => { + // This error is caused by bad user input so don't log it + (rejection.status(), rejection.body_text()) + } + AppError::TimeError(err) => { + // Because `TraceLayer` wraps each request in a span that contains the request + // method, uri, etc we don't need to include those details here + tracing::error!(%err, "error from time_library"); + + // Don't expose any details about the error to the client + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Something went wrong".to_owned(), + ) + } + }; + + (status, AppJson(ErrorResponse { message })).into_response() + } +} + +impl From for AppError { + fn from(rejection: JsonRejection) -> Self { + Self::JsonRejection(rejection) + } +} + +impl From for AppError { + fn from(error: time_library::Error) -> Self { + Self::TimeError(error) + } +} + +// Imagine this is some third party library that we're using. It sometimes returns errors which we +// want to log. +mod time_library { + use std::sync::atomic::{AtomicU64, Ordering}; + + use serde::Serialize; + + #[derive(Serialize, Clone)] + pub struct Timestamp(u64); + + impl Timestamp { + pub fn now() -> Result { + static COUNTER: AtomicU64 = AtomicU64::new(0); + + // Fail on every third call just to simulate errors + if COUNTER.fetch_add(1, Ordering::SeqCst) % 3 == 0 { + Err(Error::FailedToGetTime) + } else { + Ok(Self(1337)) + } + } + } + + #[derive(Debug)] + pub enum Error { + FailedToGetTime, + } + + impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "failed to get time") + } + } +}