mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 22:56:46 +01:00
Rework error handling example (#2382)
This commit is contained in:
parent
6c276c3ff0
commit
c3db223532
3 changed files with 235 additions and 160 deletions
|
@ -1,155 +0,0 @@
|
|||
//! Example showing how to convert errors into responses and how one might do
|
||||
//! dependency injection using trait objects.
|
||||
//!
|
||||
//! Run with
|
||||
//!
|
||||
//! ```not_rust
|
||||
//! cargo run -p example-error-handling-and-dependency-injection
|
||||
//! ```
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "example_error_handling_and_dependency_injection=debug".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
// Inject a `UserRepo` into our handlers via a trait object. This could be
|
||||
// the live implementation or just a mock for testing.
|
||||
let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo;
|
||||
|
||||
// Build our application with some routes
|
||||
let app = Router::new()
|
||||
.route("/users/:id", get(users_show))
|
||||
.route("/users", post(users_create))
|
||||
.with_state(user_repo);
|
||||
|
||||
// Run our application
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
/// Handler for `GET /users/:id`.
|
||||
///
|
||||
/// Extracts the user repo from request extensions and calls it. `UserRepoError`s
|
||||
/// are automatically converted into `AppError` which implements `IntoResponse`
|
||||
/// so it can be returned from handlers directly.
|
||||
async fn users_show(
|
||||
Path(user_id): Path<Uuid>,
|
||||
State(user_repo): State<DynUserRepo>,
|
||||
) -> Result<Json<User>, AppError> {
|
||||
let user = user_repo.find(user_id).await?;
|
||||
|
||||
Ok(user.into())
|
||||
}
|
||||
|
||||
/// Handler for `POST /users`.
|
||||
async fn users_create(
|
||||
State(user_repo): State<DynUserRepo>,
|
||||
Json(params): Json<CreateUser>,
|
||||
) -> Result<Json<User>, AppError> {
|
||||
let user = user_repo.create(params).await?;
|
||||
|
||||
Ok(user.into())
|
||||
}
|
||||
|
||||
/// Our app's top level error type.
|
||||
enum AppError {
|
||||
/// Something went wrong when calling the user repo.
|
||||
UserRepo(UserRepoError),
|
||||
}
|
||||
|
||||
/// This makes it possible to use `?` to automatically convert a `UserRepoError`
|
||||
/// into an `AppError`.
|
||||
impl From<UserRepoError> for AppError {
|
||||
fn from(inner: UserRepoError) -> Self {
|
||||
AppError::UserRepo(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
AppError::UserRepo(UserRepoError::NotFound) => {
|
||||
(StatusCode::NOT_FOUND, "User not found")
|
||||
}
|
||||
AppError::UserRepo(UserRepoError::InvalidUsername) => {
|
||||
(StatusCode::UNPROCESSABLE_ENTITY, "Invalid username")
|
||||
}
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
"error": error_message,
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Example implementation of `UserRepo`.
|
||||
struct ExampleUserRepo;
|
||||
|
||||
#[async_trait]
|
||||
impl UserRepo for ExampleUserRepo {
|
||||
async fn find(&self, _user_id: Uuid) -> Result<User, UserRepoError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn create(&self, _params: CreateUser) -> Result<User, UserRepoError> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias that makes it easier to extract `UserRepo` trait objects.
|
||||
type DynUserRepo = Arc<dyn UserRepo + Send + Sync>;
|
||||
|
||||
/// A trait that defines things a user repo might support.
|
||||
#[async_trait]
|
||||
trait UserRepo {
|
||||
/// Loop up a user by their id.
|
||||
async fn find(&self, user_id: Uuid) -> Result<User, UserRepoError>;
|
||||
|
||||
/// Create a new user.
|
||||
async fn create(&self, params: CreateUser) -> Result<User, UserRepoError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct User {
|
||||
id: Uuid,
|
||||
username: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct CreateUser {
|
||||
username: String,
|
||||
}
|
||||
|
||||
/// Errors that can happen when using the user repo.
|
||||
#[derive(Debug)]
|
||||
enum UserRepoError {
|
||||
#[allow(dead_code)]
|
||||
NotFound,
|
||||
#[allow(dead_code)]
|
||||
InvalidUsername,
|
||||
}
|
|
@ -1,15 +1,13 @@
|
|||
[package]
|
||||
name = "example-error-handling-and-dependency-injection"
|
||||
name = "example-error-handling"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
axum = { path = "../../axum" }
|
||||
axum = { path = "../../axum", features = ["macros"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tower = { version = "0.4", features = ["util"] }
|
||||
tower-http = { version = "0.5", features = ["trace"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
232
examples/error-handling/src/main.rs
Normal file
232
examples/error-handling/src/main.rs
Normal file
|
@ -0,0 +1,232 @@
|
|||
//! Example showing how to convert errors into responses.
|
||||
//!
|
||||
//! Run with
|
||||
//!
|
||||
//! ```not_rust
|
||||
//! cargo run -p example-error-handling
|
||||
//! ```
|
||||
//!
|
||||
//! For successful requests the log output will be
|
||||
//!
|
||||
//! ```ignore
|
||||
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
|
||||
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=200
|
||||
//! ```
|
||||
//!
|
||||
//! For failed requests the log output will be
|
||||
//!
|
||||
//! ```ignore
|
||||
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_request: started processing request
|
||||
//! ERROR request{method=POST uri=/users matched_path="/users"}: example_error_handling: error from time_library err=failed to get time
|
||||
//! DEBUG request{method=POST uri=/users matched_path="/users"}: tower_http::trace::on_response: finished processing request latency=0 ms status=500
|
||||
//! ```
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time_library::Timestamp;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let state = AppState::default();
|
||||
|
||||
let app = Router::new()
|
||||
// A dummy route that accepts some JSON but sometimes fails
|
||||
.route("/users", post(users_create))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
// Create our own span for the request and include the matched path. The matched
|
||||
// path is useful for figuring out which handler the request was routed to.
|
||||
.make_span_with(|req: &Request| {
|
||||
let method = req.method();
|
||||
let uri = req.uri();
|
||||
|
||||
// axum automatically adds this extension.
|
||||
let matched_path = req
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.map(|matched_path| matched_path.as_str());
|
||||
|
||||
tracing::debug_span!("request", %method, %uri, matched_path)
|
||||
})
|
||||
// By default `TraceLayer` will log 5xx responses but we're doing our specific
|
||||
// logging of errors so disable that
|
||||
.on_failure(()),
|
||||
)
|
||||
.with_state(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||
.await
|
||||
.unwrap();
|
||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct AppState {
|
||||
next_id: Arc<AtomicU64>,
|
||||
users: Arc<Mutex<HashMap<u64, User>>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserParams {
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
struct User {
|
||||
id: u64,
|
||||
name: String,
|
||||
created_at: Timestamp,
|
||||
}
|
||||
|
||||
async fn users_create(
|
||||
State(state): State<AppState>,
|
||||
// Make sure to use our own JSON extractor so we get input errors formatted in a way that
|
||||
// matches our application
|
||||
AppJson(params): AppJson<UserParams>,
|
||||
) -> Result<AppJson<User>, AppError> {
|
||||
let id = state.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
// We have implemented `From<time_library::Error> for AppError` which allows us to use `?` to
|
||||
// automatically convert the error
|
||||
let created_at = Timestamp::now()?;
|
||||
|
||||
let user = User {
|
||||
id,
|
||||
name: params.name,
|
||||
created_at,
|
||||
};
|
||||
|
||||
state.users.lock().unwrap().insert(id, user.clone());
|
||||
|
||||
Ok(AppJson(user))
|
||||
}
|
||||
|
||||
// Create our own JSON extractor by wrapping `axum::Json`. This makes it easy to override the
|
||||
// rejection and provide our own which formats errors to match our application.
|
||||
//
|
||||
// `axum::Json` responds with plain text if the input is invalid.
|
||||
#[derive(FromRequest)]
|
||||
#[from_request(via(axum::Json), rejection(AppError))]
|
||||
struct AppJson<T>(T);
|
||||
|
||||
impl<T> IntoResponse for AppJson<T>
|
||||
where
|
||||
axum::Json<T>: IntoResponse,
|
||||
{
|
||||
fn into_response(self) -> Response {
|
||||
axum::Json(self.0).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// The kinds of errors we can hit in our application.
|
||||
enum AppError {
|
||||
// The request body contained invalid JSON
|
||||
JsonRejection(JsonRejection),
|
||||
// Some error from a third party library we're using
|
||||
TimeError(time_library::Error),
|
||||
}
|
||||
|
||||
// Tell axum how `AppError` should be converted into a response.
|
||||
//
|
||||
// This is also a convenient place to log errors.
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
// How we want errors responses to be serialized
|
||||
#[derive(Serialize)]
|
||||
struct ErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
let (status, message) = match self {
|
||||
AppError::JsonRejection(rejection) => {
|
||||
// This error is caused by bad user input so don't log it
|
||||
(rejection.status(), rejection.body_text())
|
||||
}
|
||||
AppError::TimeError(err) => {
|
||||
// Because `TraceLayer` wraps each request in a span that contains the request
|
||||
// method, uri, etc we don't need to include those details here
|
||||
tracing::error!(%err, "error from time_library");
|
||||
|
||||
// Don't expose any details about the error to the client
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Something went wrong".to_owned(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
(status, AppJson(ErrorResponse { message })).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsonRejection> for AppError {
|
||||
fn from(rejection: JsonRejection) -> Self {
|
||||
Self::JsonRejection(rejection)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<time_library::Error> for AppError {
|
||||
fn from(error: time_library::Error) -> Self {
|
||||
Self::TimeError(error)
|
||||
}
|
||||
}
|
||||
|
||||
// Imagine this is some third party library that we're using. It sometimes returns errors which we
|
||||
// want to log.
|
||||
mod time_library {
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Timestamp(u64);
|
||||
|
||||
impl Timestamp {
|
||||
pub fn now() -> Result<Self, Error> {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
// Fail on every third call just to simulate errors
|
||||
if COUNTER.fetch_add(1, Ordering::SeqCst) % 3 == 0 {
|
||||
Err(Error::FailedToGetTime)
|
||||
} else {
|
||||
Ok(Self(1337))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
FailedToGetTime,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to get time")
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue