diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 75f9bf81..1c9d8311 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -46,6 +46,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **breaking:** The `Body` and `BodyError` associated types on the `IntoResponse` trait have been removed - instead, `.into_response()` will now always return `Response` ([#571]) +- **breaking:** `PathParamsRejection` has been renamed to `PathRejection` and its + variants renamed to `FailedToDeserializePathParams` and `MissingPathParams`. This + makes it more consistent with the rest of axum ([#574]) +- **added:** `Path`'s rejection type now provides data about exactly which part of + the path couldn't be deserialized ([#574]) [#525]: https://github.com/tokio-rs/axum/pull/525 [#527]: https://github.com/tokio-rs/axum/pull/527 @@ -54,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#554]: https://github.com/tokio-rs/axum/pull/554 [#564]: https://github.com/tokio-rs/axum/pull/564 [#571]: https://github.com/tokio-rs/axum/pull/571 +[#574]: https://github.com/tokio-rs/axum/pull/574 # 0.3.3 (13. November, 2021) diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index ab7ff2c3..ee6488c1 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -6,6 +6,7 @@ use rejection::*; pub mod connect_info; pub mod extractor_middleware; +pub mod path; pub mod rejection; #[cfg(feature = "ws")] @@ -16,7 +17,6 @@ mod content_length_limit; mod extension; mod form; mod matched_path; -mod path; mod query; mod raw_query; mod request_parts; diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index e3c437ee..a4d44596 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -1,36 +1,10 @@ +use super::{ErrorKind, PathDeserializationError}; use crate::util::{ByteStr, PercentDecodedByteStr}; use serde::{ de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, forward_to_deserialize_any, Deserializer, }; -use std::fmt::{self, Display}; - -/// This type represents errors that can occur when deserializing. -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct PathDeserializerError(pub(crate) String); - -impl de::Error for PathDeserializerError { - #[inline] - fn custom(msg: T) -> Self { - PathDeserializerError(msg.to_string()) - } -} - -impl std::error::Error for PathDeserializerError { - #[inline] - fn description(&self) -> &str { - "path deserializer error" - } -} - -impl fmt::Display for PathDeserializerError { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - PathDeserializerError(msg) => write!(f, "{}", msg), - } - } -} +use std::any::type_name; macro_rules! unsupported_type { ($trait_fn:ident, $name:literal) => { @@ -38,36 +12,30 @@ macro_rules! unsupported_type { where V: Visitor<'de>, { - Err(PathDeserializerError::custom(concat!( - "unsupported type: ", - $name - ))) + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) } }; } macro_rules! parse_single_value { - ($trait_fn:ident, $visit_fn:ident, $tp:literal) => { + ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { fn $trait_fn(self, visitor: V) -> Result where V: Visitor<'de>, { if self.url_params.len() != 1 { - return Err(PathDeserializerError::custom( - format!( - "wrong number of parameters: {} expected 1", - self.url_params.len() - ) - .as_str(), - )); + return Err(PathDeserializationError::wrong_number_of_parameters() + .got(self.url_params.len()) + .expected(1)); } let value = self.url_params[0].1.parse().map_err(|_| { - PathDeserializerError::custom(format!( - "can not parse `{:?}` to a `{}`", - self.url_params[0].1.as_str(), - $tp - )) + PathDeserializationError::new(ErrorKind::ParseError { + value: self.url_params[0].1.as_str().to_owned(), + expected_type: $ty, + }) })?; visitor.$visit_fn(value) } @@ -86,7 +54,7 @@ impl<'de> PathDeserializer<'de> { } impl<'de> Deserializer<'de> for PathDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; unsupported_type!(deserialize_any, "'any'"); unsupported_type!(deserialize_bytes, "bytes"); @@ -116,10 +84,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { V: Visitor<'de>, { if self.url_params.len() != 1 { - return Err(PathDeserializerError::custom(format!( - "wrong number of parameters: {} expected 1", - self.url_params.len() - ))); + return Err(PathDeserializationError::wrong_number_of_parameters() + .got(self.url_params.len()) + .expected(1)); } visitor.visit_str(&self.url_params[0].1) } @@ -159,6 +126,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { { visitor.visit_seq(SeqDeserializer { params: self.url_params, + idx: 0, }) } @@ -167,17 +135,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { V: Visitor<'de>, { if self.url_params.len() < len { - return Err(PathDeserializerError::custom( - format!( - "wrong number of parameters: {} expected {}", - self.url_params.len(), - len - ) - .as_str(), - )); + return Err(PathDeserializationError::wrong_number_of_parameters() + .got(self.url_params.len()) + .expected(len)); } visitor.visit_seq(SeqDeserializer { params: self.url_params, + idx: 0, }) } @@ -191,17 +155,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { V: Visitor<'de>, { if self.url_params.len() < len { - return Err(PathDeserializerError::custom( - format!( - "wrong number of parameters: {} expected {}", - self.url_params.len(), - len - ) - .as_str(), - )); + return Err(PathDeserializationError::wrong_number_of_parameters() + .got(self.url_params.len()) + .expected(len)); } visitor.visit_seq(SeqDeserializer { params: self.url_params, + idx: 0, }) } @@ -212,6 +172,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { visitor.visit_map(MapDeserializer { params: self.url_params, value: None, + key: None, }) } @@ -237,10 +198,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { V: Visitor<'de>, { if self.url_params.len() != 1 { - return Err(PathDeserializerError::custom(format!( - "wrong number of parameters: {} expected 1", - self.url_params.len() - ))); + return Err(PathDeserializationError::wrong_number_of_parameters() + .got(self.url_params.len()) + .expected(1)); } visitor.visit_enum(EnumDeserializer { @@ -251,11 +211,12 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { struct MapDeserializer<'de> { params: &'de [(ByteStr, PercentDecodedByteStr)], + key: Option, value: Option<&'de str>, } impl<'de> MapAccess<'de> for MapDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> where @@ -265,6 +226,7 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { Some(((key, value), tail)) => { self.value = Some(value); self.params = tail; + self.key = Some(KeyOrIdx::Key(key.clone())); seed.deserialize(KeyDeserializer { key }).map(Some) } None => Ok(None), @@ -276,8 +238,11 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { V: DeserializeSeed<'de>, { match self.value.take() { - Some(value) => seed.deserialize(ValueDeserializer { value }), - None => Err(serde::de::Error::custom("value is missing")), + Some(value) => seed.deserialize(ValueDeserializer { + key: self.key.take(), + value, + }), + None => Err(PathDeserializationError::custom("value is missing")), } } } @@ -298,7 +263,7 @@ macro_rules! parse_key { } impl<'de> Deserializer<'de> for KeyDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; parse_key!(deserialize_identifier); parse_key!(deserialize_str); @@ -308,7 +273,7 @@ impl<'de> Deserializer<'de> for KeyDeserializer<'de> { where V: Visitor<'de>, { - Err(PathDeserializerError::custom("Unexpected")) + Err(PathDeserializationError::custom("Unexpected key type")) } forward_to_deserialize_any! { @@ -320,15 +285,31 @@ impl<'de> Deserializer<'de> for KeyDeserializer<'de> { macro_rules! parse_value { ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { - fn $trait_fn(self, visitor: V) -> Result + fn $trait_fn(mut self, visitor: V) -> Result where V: Visitor<'de>, { let v = self.value.parse().map_err(|_| { - PathDeserializerError::custom(format!( - "can not parse `{:?}` to a `{}`", - self.value, $ty - )) + if let Some(key) = self.key.take() { + let kind = match key { + KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { + key: key.as_str().to_owned(), + value: self.value.to_owned(), + expected_type: $ty, + }, + KeyOrIdx::Idx(index) => ErrorKind::ParseErrorAtIndex { + index, + value: self.value.to_owned(), + expected_type: $ty, + }, + }; + PathDeserializationError::new(kind) + } else { + PathDeserializationError::new(ErrorKind::ParseError { + value: self.value.to_owned(), + expected_type: $ty, + }) + } })?; visitor.$visit_fn(v) } @@ -336,11 +317,12 @@ macro_rules! parse_value { } struct ValueDeserializer<'de> { + key: Option, value: &'de str, } impl<'de> Deserializer<'de> for ValueDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; unsupported_type!(deserialize_any, "any"); unsupported_type!(deserialize_seq, "seq"); @@ -418,7 +400,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - Err(PathDeserializerError::custom("unsupported type: tuple")) + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) } fn deserialize_tuple_struct( @@ -430,9 +414,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - Err(PathDeserializerError::custom( - "unsupported type: tuple struct", - )) + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) } fn deserialize_struct( @@ -444,7 +428,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - Err(PathDeserializerError::custom("unsupported type: struct")) + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) } fn deserialize_enum( @@ -472,7 +458,7 @@ struct EnumDeserializer<'de> { } impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; type Variant = UnitVariant; fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> @@ -489,7 +475,7 @@ impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { struct UnitVariant; impl<'de> VariantAccess<'de> for UnitVariant { - type Error = PathDeserializerError; + type Error = PathDeserializationError; fn unit_variant(self) -> Result<(), Self::Error> { Ok(()) @@ -499,14 +485,18 @@ impl<'de> VariantAccess<'de> for UnitVariant { where T: DeserializeSeed<'de>, { - Err(PathDeserializerError::custom("not supported")) + Err(PathDeserializationError::unsupported_type( + "newtype enum variant", + )) } fn tuple_variant(self, _len: usize, _visitor: V) -> Result where V: Visitor<'de>, { - Err(PathDeserializerError::custom("not supported")) + Err(PathDeserializationError::unsupported_type( + "tuple enum variant", + )) } fn struct_variant( @@ -517,16 +507,19 @@ impl<'de> VariantAccess<'de> for UnitVariant { where V: Visitor<'de>, { - Err(PathDeserializerError::custom("not supported")) + Err(PathDeserializationError::unsupported_type( + "struct enum variant", + )) } } struct SeqDeserializer<'de> { params: &'de [(ByteStr, PercentDecodedByteStr)], + idx: usize, } impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { - type Error = PathDeserializerError; + type Error = PathDeserializationError; fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> where @@ -535,13 +528,24 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { match self.params.split_first() { Some(((_, value), tail)) => { self.params = tail; - Ok(Some(seed.deserialize(ValueDeserializer { value })?)) + let idx = self.idx; + self.idx += 1; + Ok(Some(seed.deserialize(ValueDeserializer { + key: Some(KeyOrIdx::Idx(idx)), + value, + })?)) } None => Ok(None), } } } +#[derive(Clone)] +enum KeyOrIdx { + Key(ByteStr), + Idx(usize), +} + #[cfg(test)] mod tests { use super::*; @@ -614,10 +618,16 @@ mod tests { ); let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); - assert_eq!( - i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(), - PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_owned()) - ); + let error_kind = i32::deserialize(PathDeserializer::new(&url_params)) + .unwrap_err() + .kind; + assert!(matches!( + error_kind, + ErrorKind::WrongNumberOfParameters { + expected: 1, + got: 2 + } + )); } #[test] @@ -661,6 +671,38 @@ mod tests { ); } + #[test] + fn test_parse_struct_ignoring_additional_fields() { + let url_params = create_url_params(vec![ + ("a", "1"), + ("b", "true"), + ("c", "abc"), + ("d", "false"), + ]); + assert_eq!( + Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(), + Struct { + c: "abc".to_owned(), + b: true, + a: 1, + } + ); + } + + #[test] + fn test_parse_tuple_ignoring_additional_fields() { + let url_params = create_url_params(vec![ + ("a", "abc"), + ("b", "true"), + ("c", "1"), + ("d", "false"), + ]); + assert_eq!( + <(&str, bool, u32)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ("abc", true, 1) + ); + } + #[test] fn test_parse_map() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); @@ -672,4 +714,110 @@ mod tests { .collect() ); } + + macro_rules! test_parse_error { + ( + $params:expr, + $ty:ty, + $expected_error_kind:expr $(,)? + ) => { + let url_params = create_url_params($params); + let actual_error_kind = <$ty>::deserialize(PathDeserializer::new(&url_params)) + .unwrap_err() + .kind; + assert_eq!(actual_error_kind, $expected_error_kind); + }; + } + + #[test] + fn test_wrong_number_of_parameters_error() { + test_parse_error!( + vec![("a", "1")], + (u32, u32), + ErrorKind::WrongNumberOfParameters { + got: 1, + expected: 2, + } + ); + } + + #[test] + fn test_parse_error_at_key_error() { + #[derive(Debug, Deserialize)] + struct Params { + a: u32, + } + test_parse_error!( + vec![("a", "false")], + Params, + ErrorKind::ParseErrorAtKey { + key: "a".to_owned(), + value: "false".to_owned(), + expected_type: "u32", + } + ); + } + + #[test] + fn test_parse_error_at_key_error_multiple() { + #[derive(Debug, Deserialize)] + struct Params { + a: u32, + b: u32, + } + test_parse_error!( + vec![("a", "false")], + Params, + ErrorKind::ParseErrorAtKey { + key: "a".to_owned(), + value: "false".to_owned(), + expected_type: "u32", + } + ); + } + + #[test] + fn test_parse_error_at_index_error() { + test_parse_error!( + vec![("a", "false"), ("b", "true")], + (bool, u32), + ErrorKind::ParseErrorAtIndex { + index: 1, + value: "true".to_owned(), + expected_type: "u32", + } + ); + } + + #[test] + fn test_parse_error_error() { + test_parse_error!( + vec![("a", "false")], + u32, + ErrorKind::ParseError { + value: "false".to_owned(), + expected_type: "u32", + } + ); + } + + #[test] + fn test_unsupported_type_error_nested_data_structure() { + test_parse_error!( + vec![("a", "false")], + Vec>, + ErrorKind::UnsupportedType { + name: "alloc::vec::Vec", + } + ); + } + + #[test] + fn test_unsupported_type_error_tuple() { + test_parse_error!( + vec![("a", "false")], + Vec<(u32, u32)>, + ErrorKind::UnsupportedType { name: "(u32, u32)" } + ); + } } diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index cd4023f7..c1ee0bb3 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -1,14 +1,20 @@ +//! Extractor that will get captures from the URL and parse them using +//! [`serde`]. + mod de; -use super::{rejection::*, FromRequest}; use crate::{ - extract::RequestParts, + body::{boxed, BoxBody, Full}, + extract::{rejection::*, FromRequest, RequestParts}, + response::IntoResponse, routing::{InvalidUtf8InPathParam, UrlParams}, }; use async_trait::async_trait; +use http::StatusCode; use serde::de::DeserializeOwned; use std::{ borrow::Cow, + fmt, ops::{Deref, DerefMut}, }; @@ -121,8 +127,14 @@ use std::{ /// # }; /// ``` /// +/// # Providing detailed rejection output +/// +/// If the URI cannot be deserialized into the target type the request will be rejected and an +/// error response will be returned. See [`customize-path-rejection`] for an exapmle of how to customize that error. +/// /// [`serde`]: https://crates.io/crates/serde /// [`serde::Deserialize`]: https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html +/// [`customize-path-rejection`]: https://github.com/tokio-rs/axum/blob/main/examples/customize-path-rejection/src/main.rs #[derive(Debug)] pub struct Path(pub T); @@ -148,7 +160,7 @@ where T: DeserializeOwned + Send, B: Send, { - type Rejection = PathParamsRejection; + type Rejection = PathRejection; async fn from_request(req: &mut RequestParts) -> Result { let params = match req @@ -157,20 +169,236 @@ where { Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params), Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => { - return Err(InvalidPathParam::new(key.as_str()).into()) + let err = PathDeserializationError { + kind: ErrorKind::InvalidUtf8InPathParam { + key: key.as_str().to_owned(), + }, + }; + let err = FailedToDeserializePathParams(err); + return Err(err.into()); } Some(None) => Cow::Owned(Vec::new()), None => { - return Err(MissingRouteParams.into()); + return Err(MissingPathParams.into()); } }; T::deserialize(de::PathDeserializer::new(&*params)) - .map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0))) + .map_err(|err| { + PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err)) + }) .map(Path) } } +// this wrapper type is used as the deserializer error to hide the `serde::de::Error` impl which +// would otherwise be public if we used `ErrorKind` as the error directly +#[derive(Debug)] +pub(crate) struct PathDeserializationError { + pub(super) kind: ErrorKind, +} + +impl PathDeserializationError { + pub(super) fn new(kind: ErrorKind) -> Self { + Self { kind } + } + + pub(super) fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> { + WrongNumberOfParameters { got: () } + } + + pub(super) fn unsupported_type(name: &'static str) -> Self { + Self::new(ErrorKind::UnsupportedType { name }) + } +} + +pub(super) struct WrongNumberOfParameters { + got: G, +} + +impl WrongNumberOfParameters { + #[allow(clippy::unused_self)] + pub(super) fn got(self, got: G2) -> WrongNumberOfParameters { + WrongNumberOfParameters { got } + } +} + +impl WrongNumberOfParameters { + pub(super) fn expected(self, expected: usize) -> PathDeserializationError { + PathDeserializationError::new(ErrorKind::WrongNumberOfParameters { + got: self.got, + expected, + }) + } +} + +impl serde::de::Error for PathDeserializationError { + #[inline] + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + Self { + kind: ErrorKind::Message(msg.to_string()), + } + } +} + +impl fmt::Display for PathDeserializationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.kind.fmt(f) + } +} + +impl std::error::Error for PathDeserializationError {} + +/// The kinds of errors that can happen we deserializing into a [`Path`]. +/// +/// This type is obtained through [`FailedToDeserializePathParams::into_kind`] and is useful for building +/// more precise error messages. +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub enum ErrorKind { + /// The URI contained the wrong number of parameters. + WrongNumberOfParameters { + /// The number of actual parameters in the URI. + got: usize, + /// The number of expected parameters. + expected: usize, + }, + + /// Failed to parse the value at a specific key into the expected type. + /// + /// This variant is used when deserializing into types that have named fields, such as structs. + ParseErrorAtKey { + /// The key at which the value was located. + key: String, + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// Failed to parse the value at a specific index into the expected type. + /// + /// This variant is used when deserializing into sequence types, such as tuples. + ParseErrorAtIndex { + /// The index at which the value was located. + index: usize, + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// Failed to parse a value into the expected type. + /// + /// This variant is used when deserializing into a primitive type (such as `String` and `u32`). + ParseError { + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// A parameter contained text that, once percent decoded, wasn't valid UTF-8. + InvalidUtf8InPathParam { + /// The key at which the invalid value was located. + key: String, + }, + + /// Tried to serialize into an unsupported type such as nested maps. + /// + /// This error kind is caused by programmer errors and thus gets converted into a `500 Internal + /// Server Error` response. + UnsupportedType { + /// The name of the unsupported type. + name: &'static str, + }, + + /// Catch-all variant for errors that don't fit any other variant. + Message(String), +} + +impl fmt::Display for ErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorKind::Message(error) => error.fmt(f), + ErrorKind::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{}`", key), + ErrorKind::WrongNumberOfParameters { got, expected } => write!( + f, + "Wronger number of parameters. Expected {} but got {}", + expected, got + ), + ErrorKind::UnsupportedType { name } => write!(f, "Unsupported type `{}`", name), + ErrorKind::ParseErrorAtKey { + key, + value, + expected_type, + } => write!( + f, + "Cannot parse `{}` with value `{:?}` to a `{}`", + key, value, expected_type + ), + ErrorKind::ParseError { + value, + expected_type, + } => write!(f, "Cannot parse `{:?}` to a `{}`", value, expected_type), + ErrorKind::ParseErrorAtIndex { + index, + value, + expected_type, + } => write!( + f, + "Cannot parse value at index {} with value `{:?}` to a `{}`", + index, value, expected_type + ), + } + } +} + +/// Rejection type for [`Path`](super::Path) if the captured routes params couldn't be deserialized +/// into the expected type. +#[derive(Debug)] +pub struct FailedToDeserializePathParams(PathDeserializationError); + +impl FailedToDeserializePathParams { + /// Convert this error into the underlying error kind. + pub fn into_kind(self) -> ErrorKind { + self.0.kind + } +} + +impl IntoResponse for FailedToDeserializePathParams { + fn into_response(self) -> http::Response { + let (status, body) = match self.0.kind { + ErrorKind::Message(_) + | ErrorKind::InvalidUtf8InPathParam { .. } + | ErrorKind::WrongNumberOfParameters { .. } + | ErrorKind::ParseError { .. } + | ErrorKind::ParseErrorAtIndex { .. } + | ErrorKind::ParseErrorAtKey { .. } => ( + StatusCode::BAD_REQUEST, + format!("Invalid URL: {}", self.0.kind), + ), + ErrorKind::UnsupportedType { .. } => { + (StatusCode::INTERNAL_SERVER_ERROR, self.0.kind.to_string()) + } + }; + let mut res = http::Response::new(boxed(Full::from(body))); + *res.status_mut() = status; + res + } +} + +impl fmt::Display for FailedToDeserializePathParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl std::error::Error for FailedToDeserializePathParams {} + #[cfg(test)] mod tests { use http::StatusCode; diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index ed73d14b..cb99785e 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -7,6 +7,7 @@ use crate::{ }; use http_body::Full; +pub use crate::extract::path::FailedToDeserializePathParams; pub use axum_core::extract::rejection::*; #[cfg(feature = "json")] @@ -52,9 +53,10 @@ define_rejection! { define_rejection! { #[status = INTERNAL_SERVER_ERROR] - #[body = "No url params found for matched route. This is a bug in axum. Please open an issue"] - /// Rejection type used if you try and extract the URL params more than once. - pub struct MissingRouteParams; + #[body = "No paths parameters found for matched route. This is a bug in axum. Please open an issue"] + /// Rejection type used if axum's internal representation of path parameters is missing. This + /// should never happen and is a bug in axum if it does. + pub struct MissingPathParams; } define_rejection! { @@ -64,33 +66,6 @@ define_rejection! { pub struct InvalidFormContentType; } -/// Rejection type for [`Path`](super::Path) if the capture route -/// param didn't have the expected type. -#[derive(Debug)] -pub struct InvalidPathParam(pub(crate) String); - -impl InvalidPathParam { - pub(super) fn new(err: impl Into) -> Self { - InvalidPathParam(err.into()) - } -} - -impl IntoResponse for InvalidPathParam { - fn into_response(self) -> http::Response { - let mut res = http::Response::new(boxed(Full::from(self.to_string()))); - *res.status_mut() = http::StatusCode::BAD_REQUEST; - res - } -} - -impl std::fmt::Display for InvalidPathParam { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Invalid URL param. {}", self.0) - } -} - -impl std::error::Error for InvalidPathParam {} - /// Rejection type for extractors that deserialize query strings if the input /// couldn't be deserialized into the target type. #[derive(Debug)] @@ -185,9 +160,9 @@ composite_rejection! { /// /// Contains one variant for each way the [`Path`](super::Path) extractor /// can fail. - pub enum PathParamsRejection { - InvalidPathParam, - MissingRouteParams, + pub enum PathRejection { + FailedToDeserializePathParams, + MissingPathParams, } } diff --git a/axum/src/util.rs b/axum/src/util.rs index 71c50e77..34e7b761 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use pin_project_lite::pin_project; +use std::fmt; use std::ops::Deref; /// A string like type backed by `Bytes` making it cheap to clone. @@ -30,6 +31,12 @@ impl ByteStr { } } +impl fmt::Display for ByteStr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct PercentDecodedByteStr(ByteStr); diff --git a/examples/customize-path-rejection/Cargo.toml b/examples/customize-path-rejection/Cargo.toml new file mode 100644 index 00000000..d553c668 --- /dev/null +++ b/examples/customize-path-rejection/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-customize-path-rejection" +version = "0.1.0" +edition = "2018" +publish = false + +[dependencies] +axum = { path = "../../axum" } +tokio = { version = "1.0", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tracing = "0.1" +tracing-subscriber = { version="0.3", features = ["env-filter"] } diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs new file mode 100644 index 00000000..2f2b856b --- /dev/null +++ b/examples/customize-path-rejection/src/main.rs @@ -0,0 +1,144 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-customize-path-rejection +//! ``` + +use axum::{ + async_trait, + extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts}, + http::StatusCode, + response::IntoResponse, + routing::get, + Router, +}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::net::SocketAddr; + +#[tokio::main] +async fn main() { + // Set the RUST_LOG, if it hasn't been explicitly defined + if std::env::var_os("RUST_LOG").is_none() { + std::env::set_var("RUST_LOG", "example_customize_path_rejection=debug") + } + tracing_subscriber::fmt::init(); + + // build our application with a route + let app = Router::new().route("/users/:user_id/teams/:team_id", get(handler)); + + // run it + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + println!("listening on {}", addr); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn handler(Path(params): Path) -> impl IntoResponse { + axum::Json(params) +} + +#[derive(Debug, Deserialize, Serialize)] +struct Params { + user_id: u32, + team_id: u32, +} + +// We define our own `Path` extractor that customizes the error from `axum::extract::Path` +struct Path(T); + +#[async_trait] +impl FromRequest for Path +where + // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` + T: DeserializeOwned + Send, + B: Send, +{ + type Rejection = (StatusCode, axum::Json); + + async fn from_request(req: &mut RequestParts) -> Result { + match axum::extract::Path::::from_request(req).await { + Ok(value) => Ok(Self(value.0)), + Err(rejection) => { + let (status, body) = match rejection { + PathRejection::FailedToDeserializePathParams(inner) => { + let mut status = StatusCode::BAD_REQUEST; + + let kind = inner.into_kind(); + let body = match &kind { + ErrorKind::WrongNumberOfParameters { .. } => PathError { + message: kind.to_string(), + location: None, + }, + + ErrorKind::ParseErrorAtKey { key, .. } => PathError { + message: kind.to_string(), + location: Some(key.clone()), + }, + + ErrorKind::ParseErrorAtIndex { index, .. } => PathError { + message: kind.to_string(), + location: Some(index.to_string()), + }, + + ErrorKind::ParseError { .. } => PathError { + message: kind.to_string(), + location: None, + }, + + ErrorKind::InvalidUtf8InPathParam { key } => PathError { + message: kind.to_string(), + location: Some(key.clone()), + }, + + ErrorKind::UnsupportedType { .. } => { + // this error is caused by the programmer using an unsupported type + // (such as nested maps) so respond with `500` instead + status = StatusCode::INTERNAL_SERVER_ERROR; + PathError { + message: kind.to_string(), + location: None, + } + } + + ErrorKind::Message(msg) => PathError { + message: msg.clone(), + location: None, + }, + + _ => PathError { + message: format!("Unhandled deserialization error: {}", kind), + location: None, + }, + }; + + (status, body) + } + PathRejection::MissingPathParams(error) => ( + StatusCode::INTERNAL_SERVER_ERROR, + PathError { + message: error.to_string(), + location: None, + }, + ), + _ => ( + StatusCode::INTERNAL_SERVER_ERROR, + PathError { + message: format!("Unhandled path rejection: {}", rejection), + location: None, + }, + ), + }; + + Err((status, axum::Json(body))) + } + } + } +} + +#[derive(Serialize)] +struct PathError { + message: String, + location: Option, +}