Provide more data in Path deserialization error (#574)

* Provide more error in `Path` deserialization error

* Rename

* Add error kind for deserializing sequences

* Rename

* Fix wrong docs

* Rename `MissingRouteParams`

* Rename error to have more consistency

* Rename internal error

* Update changelog

* One last renaming, for now

* Add tests

* Tweak changelog a bit
This commit is contained in:
David Pedersen 2021-12-02 08:51:29 +01:00 committed by GitHub
parent 66c5142d0c
commit 3ec680cce7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 658 additions and 137 deletions

View file

@ -46,6 +46,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **breaking:** The `Body` and `BodyError` associated types on the - **breaking:** The `Body` and `BodyError` associated types on the
`IntoResponse` trait have been removed - instead, `.into_response()` will now `IntoResponse` trait have been removed - instead, `.into_response()` will now
always return `Response<BoxBody>` ([#571]) always return `Response<BoxBody>` ([#571])
- **breaking:** `PathParamsRejection` has been renamed to `PathRejection` and its
variants renamed to `FailedToDeserializePathParams` and `MissingPathParams`. This
makes it more consistent with the rest of axum ([#574])
- **added:** `Path`'s rejection type now provides data about exactly which part of
the path couldn't be deserialized ([#574])
[#525]: https://github.com/tokio-rs/axum/pull/525 [#525]: https://github.com/tokio-rs/axum/pull/525
[#527]: https://github.com/tokio-rs/axum/pull/527 [#527]: https://github.com/tokio-rs/axum/pull/527
@ -54,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#554]: https://github.com/tokio-rs/axum/pull/554 [#554]: https://github.com/tokio-rs/axum/pull/554
[#564]: https://github.com/tokio-rs/axum/pull/564 [#564]: https://github.com/tokio-rs/axum/pull/564
[#571]: https://github.com/tokio-rs/axum/pull/571 [#571]: https://github.com/tokio-rs/axum/pull/571
[#574]: https://github.com/tokio-rs/axum/pull/574
# 0.3.3 (13. November, 2021) # 0.3.3 (13. November, 2021)

View file

@ -6,6 +6,7 @@ use rejection::*;
pub mod connect_info; pub mod connect_info;
pub mod extractor_middleware; pub mod extractor_middleware;
pub mod path;
pub mod rejection; pub mod rejection;
#[cfg(feature = "ws")] #[cfg(feature = "ws")]
@ -16,7 +17,6 @@ mod content_length_limit;
mod extension; mod extension;
mod form; mod form;
mod matched_path; mod matched_path;
mod path;
mod query; mod query;
mod raw_query; mod raw_query;
mod request_parts; mod request_parts;

View file

@ -1,36 +1,10 @@
use super::{ErrorKind, PathDeserializationError};
use crate::util::{ByteStr, PercentDecodedByteStr}; use crate::util::{ByteStr, PercentDecodedByteStr};
use serde::{ use serde::{
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
forward_to_deserialize_any, Deserializer, forward_to_deserialize_any, Deserializer,
}; };
use std::fmt::{self, Display}; use std::any::type_name;
/// This type represents errors that can occur when deserializing.
#[derive(Debug, Eq, PartialEq)]
pub(crate) struct PathDeserializerError(pub(crate) String);
impl de::Error for PathDeserializerError {
#[inline]
fn custom<T: Display>(msg: T) -> Self {
PathDeserializerError(msg.to_string())
}
}
impl std::error::Error for PathDeserializerError {
#[inline]
fn description(&self) -> &str {
"path deserializer error"
}
}
impl fmt::Display for PathDeserializerError {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
PathDeserializerError(msg) => write!(f, "{}", msg),
}
}
}
macro_rules! unsupported_type { macro_rules! unsupported_type {
($trait_fn:ident, $name:literal) => { ($trait_fn:ident, $name:literal) => {
@ -38,36 +12,30 @@ macro_rules! unsupported_type {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom(concat!( Err(PathDeserializationError::unsupported_type(type_name::<
"unsupported type: ", V::Value,
$name >()))
)))
} }
}; };
} }
macro_rules! parse_single_value { macro_rules! parse_single_value {
($trait_fn:ident, $visit_fn:ident, $tp:literal) => { ($trait_fn:ident, $visit_fn:ident, $ty:literal) => {
fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom( return Err(PathDeserializationError::wrong_number_of_parameters()
format!( .got(self.url_params.len())
"wrong number of parameters: {} expected 1", .expected(1));
self.url_params.len()
)
.as_str(),
));
} }
let value = self.url_params[0].1.parse().map_err(|_| { let value = self.url_params[0].1.parse().map_err(|_| {
PathDeserializerError::custom(format!( PathDeserializationError::new(ErrorKind::ParseError {
"can not parse `{:?}` to a `{}`", value: self.url_params[0].1.as_str().to_owned(),
self.url_params[0].1.as_str(), expected_type: $ty,
$tp })
))
})?; })?;
visitor.$visit_fn(value) visitor.$visit_fn(value)
} }
@ -86,7 +54,7 @@ impl<'de> PathDeserializer<'de> {
} }
impl<'de> Deserializer<'de> for PathDeserializer<'de> { impl<'de> Deserializer<'de> for PathDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
unsupported_type!(deserialize_any, "'any'"); unsupported_type!(deserialize_any, "'any'");
unsupported_type!(deserialize_bytes, "bytes"); unsupported_type!(deserialize_bytes, "bytes");
@ -116,10 +84,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!( return Err(PathDeserializationError::wrong_number_of_parameters()
"wrong number of parameters: {} expected 1", .got(self.url_params.len())
self.url_params.len() .expected(1));
)));
} }
visitor.visit_str(&self.url_params[0].1) visitor.visit_str(&self.url_params[0].1)
} }
@ -159,6 +126,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
{ {
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: self.url_params, params: self.url_params,
idx: 0,
}) })
} }
@ -167,17 +135,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.len() < len { if self.url_params.len() < len {
return Err(PathDeserializerError::custom( return Err(PathDeserializationError::wrong_number_of_parameters()
format!( .got(self.url_params.len())
"wrong number of parameters: {} expected {}", .expected(len));
self.url_params.len(),
len
)
.as_str(),
));
} }
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: self.url_params, params: self.url_params,
idx: 0,
}) })
} }
@ -191,17 +155,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.len() < len { if self.url_params.len() < len {
return Err(PathDeserializerError::custom( return Err(PathDeserializationError::wrong_number_of_parameters()
format!( .got(self.url_params.len())
"wrong number of parameters: {} expected {}", .expected(len));
self.url_params.len(),
len
)
.as_str(),
));
} }
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: self.url_params, params: self.url_params,
idx: 0,
}) })
} }
@ -212,6 +172,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
visitor.visit_map(MapDeserializer { visitor.visit_map(MapDeserializer {
params: self.url_params, params: self.url_params,
value: None, value: None,
key: None,
}) })
} }
@ -237,10 +198,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!( return Err(PathDeserializationError::wrong_number_of_parameters()
"wrong number of parameters: {} expected 1", .got(self.url_params.len())
self.url_params.len() .expected(1));
)));
} }
visitor.visit_enum(EnumDeserializer { visitor.visit_enum(EnumDeserializer {
@ -251,11 +211,12 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
struct MapDeserializer<'de> { struct MapDeserializer<'de> {
params: &'de [(ByteStr, PercentDecodedByteStr)], params: &'de [(ByteStr, PercentDecodedByteStr)],
key: Option<KeyOrIdx>,
value: Option<&'de str>, value: Option<&'de str>,
} }
impl<'de> MapAccess<'de> for MapDeserializer<'de> { impl<'de> MapAccess<'de> for MapDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
where where
@ -265,6 +226,7 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> {
Some(((key, value), tail)) => { Some(((key, value), tail)) => {
self.value = Some(value); self.value = Some(value);
self.params = tail; self.params = tail;
self.key = Some(KeyOrIdx::Key(key.clone()));
seed.deserialize(KeyDeserializer { key }).map(Some) seed.deserialize(KeyDeserializer { key }).map(Some)
} }
None => Ok(None), None => Ok(None),
@ -276,8 +238,11 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> {
V: DeserializeSeed<'de>, V: DeserializeSeed<'de>,
{ {
match self.value.take() { match self.value.take() {
Some(value) => seed.deserialize(ValueDeserializer { value }), Some(value) => seed.deserialize(ValueDeserializer {
None => Err(serde::de::Error::custom("value is missing")), key: self.key.take(),
value,
}),
None => Err(PathDeserializationError::custom("value is missing")),
} }
} }
} }
@ -298,7 +263,7 @@ macro_rules! parse_key {
} }
impl<'de> Deserializer<'de> for KeyDeserializer<'de> { impl<'de> Deserializer<'de> for KeyDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
parse_key!(deserialize_identifier); parse_key!(deserialize_identifier);
parse_key!(deserialize_str); parse_key!(deserialize_str);
@ -308,7 +273,7 @@ impl<'de> Deserializer<'de> for KeyDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom("Unexpected")) Err(PathDeserializationError::custom("Unexpected key type"))
} }
forward_to_deserialize_any! { forward_to_deserialize_any! {
@ -320,15 +285,31 @@ impl<'de> Deserializer<'de> for KeyDeserializer<'de> {
macro_rules! parse_value { macro_rules! parse_value {
($trait_fn:ident, $visit_fn:ident, $ty:literal) => { ($trait_fn:ident, $visit_fn:ident, $ty:literal) => {
fn $trait_fn<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn $trait_fn<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
let v = self.value.parse().map_err(|_| { let v = self.value.parse().map_err(|_| {
PathDeserializerError::custom(format!( if let Some(key) = self.key.take() {
"can not parse `{:?}` to a `{}`", let kind = match key {
self.value, $ty KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey {
)) key: key.as_str().to_owned(),
value: self.value.to_owned(),
expected_type: $ty,
},
KeyOrIdx::Idx(index) => ErrorKind::ParseErrorAtIndex {
index,
value: self.value.to_owned(),
expected_type: $ty,
},
};
PathDeserializationError::new(kind)
} else {
PathDeserializationError::new(ErrorKind::ParseError {
value: self.value.to_owned(),
expected_type: $ty,
})
}
})?; })?;
visitor.$visit_fn(v) visitor.$visit_fn(v)
} }
@ -336,11 +317,12 @@ macro_rules! parse_value {
} }
struct ValueDeserializer<'de> { struct ValueDeserializer<'de> {
key: Option<KeyOrIdx>,
value: &'de str, value: &'de str,
} }
impl<'de> Deserializer<'de> for ValueDeserializer<'de> { impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
unsupported_type!(deserialize_any, "any"); unsupported_type!(deserialize_any, "any");
unsupported_type!(deserialize_seq, "seq"); unsupported_type!(deserialize_seq, "seq");
@ -418,7 +400,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom("unsupported type: tuple")) Err(PathDeserializationError::unsupported_type(type_name::<
V::Value,
>()))
} }
fn deserialize_tuple_struct<V>( fn deserialize_tuple_struct<V>(
@ -430,9 +414,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom( Err(PathDeserializationError::unsupported_type(type_name::<
"unsupported type: tuple struct", V::Value,
)) >()))
} }
fn deserialize_struct<V>( fn deserialize_struct<V>(
@ -444,7 +428,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom("unsupported type: struct")) Err(PathDeserializationError::unsupported_type(type_name::<
V::Value,
>()))
} }
fn deserialize_enum<V>( fn deserialize_enum<V>(
@ -472,7 +458,7 @@ struct EnumDeserializer<'de> {
} }
impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { impl<'de> EnumAccess<'de> for EnumDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
type Variant = UnitVariant; type Variant = UnitVariant;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
@ -489,7 +475,7 @@ impl<'de> EnumAccess<'de> for EnumDeserializer<'de> {
struct UnitVariant; struct UnitVariant;
impl<'de> VariantAccess<'de> for UnitVariant { impl<'de> VariantAccess<'de> for UnitVariant {
type Error = PathDeserializerError; type Error = PathDeserializationError;
fn unit_variant(self) -> Result<(), Self::Error> { fn unit_variant(self) -> Result<(), Self::Error> {
Ok(()) Ok(())
@ -499,14 +485,18 @@ impl<'de> VariantAccess<'de> for UnitVariant {
where where
T: DeserializeSeed<'de>, T: DeserializeSeed<'de>,
{ {
Err(PathDeserializerError::custom("not supported")) Err(PathDeserializationError::unsupported_type(
"newtype enum variant",
))
} }
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error> fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom("not supported")) Err(PathDeserializationError::unsupported_type(
"tuple enum variant",
))
} }
fn struct_variant<V>( fn struct_variant<V>(
@ -517,16 +507,19 @@ impl<'de> VariantAccess<'de> for UnitVariant {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
Err(PathDeserializerError::custom("not supported")) Err(PathDeserializationError::unsupported_type(
"struct enum variant",
))
} }
} }
struct SeqDeserializer<'de> { struct SeqDeserializer<'de> {
params: &'de [(ByteStr, PercentDecodedByteStr)], params: &'de [(ByteStr, PercentDecodedByteStr)],
idx: usize,
} }
impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
type Error = PathDeserializerError; type Error = PathDeserializationError;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error> fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
where where
@ -535,13 +528,24 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
match self.params.split_first() { match self.params.split_first() {
Some(((_, value), tail)) => { Some(((_, value), tail)) => {
self.params = tail; self.params = tail;
Ok(Some(seed.deserialize(ValueDeserializer { value })?)) let idx = self.idx;
self.idx += 1;
Ok(Some(seed.deserialize(ValueDeserializer {
key: Some(KeyOrIdx::Idx(idx)),
value,
})?))
} }
None => Ok(None), None => Ok(None),
} }
} }
} }
#[derive(Clone)]
enum KeyOrIdx {
Key(ByteStr),
Idx(usize),
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -614,10 +618,16 @@ mod tests {
); );
let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); let url_params = create_url_params(vec![("a", "1"), ("b", "2")]);
assert_eq!( let error_kind = i32::deserialize(PathDeserializer::new(&url_params))
i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(), .unwrap_err()
PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_owned()) .kind;
); assert!(matches!(
error_kind,
ErrorKind::WrongNumberOfParameters {
expected: 1,
got: 2
}
));
} }
#[test] #[test]
@ -661,6 +671,38 @@ mod tests {
); );
} }
#[test]
fn test_parse_struct_ignoring_additional_fields() {
let url_params = create_url_params(vec![
("a", "1"),
("b", "true"),
("c", "abc"),
("d", "false"),
]);
assert_eq!(
Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(),
Struct {
c: "abc".to_owned(),
b: true,
a: 1,
}
);
}
#[test]
fn test_parse_tuple_ignoring_additional_fields() {
let url_params = create_url_params(vec![
("a", "abc"),
("b", "true"),
("c", "1"),
("d", "false"),
]);
assert_eq!(
<(&str, bool, u32)>::deserialize(PathDeserializer::new(&url_params)).unwrap(),
("abc", true, 1)
);
}
#[test] #[test]
fn test_parse_map() { fn test_parse_map() {
let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]);
@ -672,4 +714,110 @@ mod tests {
.collect() .collect()
); );
} }
macro_rules! test_parse_error {
(
$params:expr,
$ty:ty,
$expected_error_kind:expr $(,)?
) => {
let url_params = create_url_params($params);
let actual_error_kind = <$ty>::deserialize(PathDeserializer::new(&url_params))
.unwrap_err()
.kind;
assert_eq!(actual_error_kind, $expected_error_kind);
};
}
#[test]
fn test_wrong_number_of_parameters_error() {
test_parse_error!(
vec![("a", "1")],
(u32, u32),
ErrorKind::WrongNumberOfParameters {
got: 1,
expected: 2,
}
);
}
#[test]
fn test_parse_error_at_key_error() {
#[derive(Debug, Deserialize)]
struct Params {
a: u32,
}
test_parse_error!(
vec![("a", "false")],
Params,
ErrorKind::ParseErrorAtKey {
key: "a".to_owned(),
value: "false".to_owned(),
expected_type: "u32",
}
);
}
#[test]
fn test_parse_error_at_key_error_multiple() {
#[derive(Debug, Deserialize)]
struct Params {
a: u32,
b: u32,
}
test_parse_error!(
vec![("a", "false")],
Params,
ErrorKind::ParseErrorAtKey {
key: "a".to_owned(),
value: "false".to_owned(),
expected_type: "u32",
}
);
}
#[test]
fn test_parse_error_at_index_error() {
test_parse_error!(
vec![("a", "false"), ("b", "true")],
(bool, u32),
ErrorKind::ParseErrorAtIndex {
index: 1,
value: "true".to_owned(),
expected_type: "u32",
}
);
}
#[test]
fn test_parse_error_error() {
test_parse_error!(
vec![("a", "false")],
u32,
ErrorKind::ParseError {
value: "false".to_owned(),
expected_type: "u32",
}
);
}
#[test]
fn test_unsupported_type_error_nested_data_structure() {
test_parse_error!(
vec![("a", "false")],
Vec<Vec<u32>>,
ErrorKind::UnsupportedType {
name: "alloc::vec::Vec<u32>",
}
);
}
#[test]
fn test_unsupported_type_error_tuple() {
test_parse_error!(
vec![("a", "false")],
Vec<(u32, u32)>,
ErrorKind::UnsupportedType { name: "(u32, u32)" }
);
}
} }

View file

@ -1,14 +1,20 @@
//! Extractor that will get captures from the URL and parse them using
//! [`serde`].
mod de; mod de;
use super::{rejection::*, FromRequest};
use crate::{ use crate::{
extract::RequestParts, body::{boxed, BoxBody, Full},
extract::{rejection::*, FromRequest, RequestParts},
response::IntoResponse,
routing::{InvalidUtf8InPathParam, UrlParams}, routing::{InvalidUtf8InPathParam, UrlParams},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use http::StatusCode;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{ use std::{
borrow::Cow, borrow::Cow,
fmt,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
}; };
@ -121,8 +127,14 @@ use std::{
/// # }; /// # };
/// ``` /// ```
/// ///
/// # Providing detailed rejection output
///
/// If the URI cannot be deserialized into the target type the request will be rejected and an
/// error response will be returned. See [`customize-path-rejection`] for an exapmle of how to customize that error.
///
/// [`serde`]: https://crates.io/crates/serde /// [`serde`]: https://crates.io/crates/serde
/// [`serde::Deserialize`]: https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html /// [`serde::Deserialize`]: https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html
/// [`customize-path-rejection`]: https://github.com/tokio-rs/axum/blob/main/examples/customize-path-rejection/src/main.rs
#[derive(Debug)] #[derive(Debug)]
pub struct Path<T>(pub T); pub struct Path<T>(pub T);
@ -148,7 +160,7 @@ where
T: DeserializeOwned + Send, T: DeserializeOwned + Send,
B: Send, B: Send,
{ {
type Rejection = PathParamsRejection; type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let params = match req let params = match req
@ -157,20 +169,236 @@ where
{ {
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params), Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => { Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
return Err(InvalidPathParam::new(key.as_str()).into()) let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.as_str().to_owned(),
},
};
let err = FailedToDeserializePathParams(err);
return Err(err.into());
} }
Some(None) => Cow::Owned(Vec::new()), Some(None) => Cow::Owned(Vec::new()),
None => { None => {
return Err(MissingRouteParams.into()); return Err(MissingPathParams.into());
} }
}; };
T::deserialize(de::PathDeserializer::new(&*params)) T::deserialize(de::PathDeserializer::new(&*params))
.map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0))) .map_err(|err| {
PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err))
})
.map(Path) .map(Path)
} }
} }
// this wrapper type is used as the deserializer error to hide the `serde::de::Error` impl which
// would otherwise be public if we used `ErrorKind` as the error directly
#[derive(Debug)]
pub(crate) struct PathDeserializationError {
pub(super) kind: ErrorKind,
}
impl PathDeserializationError {
pub(super) fn new(kind: ErrorKind) -> Self {
Self { kind }
}
pub(super) fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> {
WrongNumberOfParameters { got: () }
}
pub(super) fn unsupported_type(name: &'static str) -> Self {
Self::new(ErrorKind::UnsupportedType { name })
}
}
pub(super) struct WrongNumberOfParameters<G> {
got: G,
}
impl<G> WrongNumberOfParameters<G> {
#[allow(clippy::unused_self)]
pub(super) fn got<G2>(self, got: G2) -> WrongNumberOfParameters<G2> {
WrongNumberOfParameters { got }
}
}
impl WrongNumberOfParameters<usize> {
pub(super) fn expected(self, expected: usize) -> PathDeserializationError {
PathDeserializationError::new(ErrorKind::WrongNumberOfParameters {
got: self.got,
expected,
})
}
}
impl serde::de::Error for PathDeserializationError {
#[inline]
fn custom<T>(msg: T) -> Self
where
T: fmt::Display,
{
Self {
kind: ErrorKind::Message(msg.to_string()),
}
}
}
impl fmt::Display for PathDeserializationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.kind.fmt(f)
}
}
impl std::error::Error for PathDeserializationError {}
/// The kinds of errors that can happen we deserializing into a [`Path`].
///
/// This type is obtained through [`FailedToDeserializePathParams::into_kind`] and is useful for building
/// more precise error messages.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum ErrorKind {
/// The URI contained the wrong number of parameters.
WrongNumberOfParameters {
/// The number of actual parameters in the URI.
got: usize,
/// The number of expected parameters.
expected: usize,
},
/// Failed to parse the value at a specific key into the expected type.
///
/// This variant is used when deserializing into types that have named fields, such as structs.
ParseErrorAtKey {
/// The key at which the value was located.
key: String,
/// The value from the URI.
value: String,
/// The expected type of the value.
expected_type: &'static str,
},
/// Failed to parse the value at a specific index into the expected type.
///
/// This variant is used when deserializing into sequence types, such as tuples.
ParseErrorAtIndex {
/// The index at which the value was located.
index: usize,
/// The value from the URI.
value: String,
/// The expected type of the value.
expected_type: &'static str,
},
/// Failed to parse a value into the expected type.
///
/// This variant is used when deserializing into a primitive type (such as `String` and `u32`).
ParseError {
/// The value from the URI.
value: String,
/// The expected type of the value.
expected_type: &'static str,
},
/// A parameter contained text that, once percent decoded, wasn't valid UTF-8.
InvalidUtf8InPathParam {
/// The key at which the invalid value was located.
key: String,
},
/// Tried to serialize into an unsupported type such as nested maps.
///
/// This error kind is caused by programmer errors and thus gets converted into a `500 Internal
/// Server Error` response.
UnsupportedType {
/// The name of the unsupported type.
name: &'static str,
},
/// Catch-all variant for errors that don't fit any other variant.
Message(String),
}
impl fmt::Display for ErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorKind::Message(error) => error.fmt(f),
ErrorKind::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{}`", key),
ErrorKind::WrongNumberOfParameters { got, expected } => write!(
f,
"Wronger number of parameters. Expected {} but got {}",
expected, got
),
ErrorKind::UnsupportedType { name } => write!(f, "Unsupported type `{}`", name),
ErrorKind::ParseErrorAtKey {
key,
value,
expected_type,
} => write!(
f,
"Cannot parse `{}` with value `{:?}` to a `{}`",
key, value, expected_type
),
ErrorKind::ParseError {
value,
expected_type,
} => write!(f, "Cannot parse `{:?}` to a `{}`", value, expected_type),
ErrorKind::ParseErrorAtIndex {
index,
value,
expected_type,
} => write!(
f,
"Cannot parse value at index {} with value `{:?}` to a `{}`",
index, value, expected_type
),
}
}
}
/// Rejection type for [`Path`](super::Path) if the captured routes params couldn't be deserialized
/// into the expected type.
#[derive(Debug)]
pub struct FailedToDeserializePathParams(PathDeserializationError);
impl FailedToDeserializePathParams {
/// Convert this error into the underlying error kind.
pub fn into_kind(self) -> ErrorKind {
self.0.kind
}
}
impl IntoResponse for FailedToDeserializePathParams {
fn into_response(self) -> http::Response<BoxBody> {
let (status, body) = match self.0.kind {
ErrorKind::Message(_)
| ErrorKind::InvalidUtf8InPathParam { .. }
| ErrorKind::WrongNumberOfParameters { .. }
| ErrorKind::ParseError { .. }
| ErrorKind::ParseErrorAtIndex { .. }
| ErrorKind::ParseErrorAtKey { .. } => (
StatusCode::BAD_REQUEST,
format!("Invalid URL: {}", self.0.kind),
),
ErrorKind::UnsupportedType { .. } => {
(StatusCode::INTERNAL_SERVER_ERROR, self.0.kind.to_string())
}
};
let mut res = http::Response::new(boxed(Full::from(body)));
*res.status_mut() = status;
res
}
}
impl fmt::Display for FailedToDeserializePathParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl std::error::Error for FailedToDeserializePathParams {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use http::StatusCode; use http::StatusCode;

View file

@ -7,6 +7,7 @@ use crate::{
}; };
use http_body::Full; use http_body::Full;
pub use crate::extract::path::FailedToDeserializePathParams;
pub use axum_core::extract::rejection::*; pub use axum_core::extract::rejection::*;
#[cfg(feature = "json")] #[cfg(feature = "json")]
@ -52,9 +53,10 @@ define_rejection! {
define_rejection! { define_rejection! {
#[status = INTERNAL_SERVER_ERROR] #[status = INTERNAL_SERVER_ERROR]
#[body = "No url params found for matched route. This is a bug in axum. Please open an issue"] #[body = "No paths parameters found for matched route. This is a bug in axum. Please open an issue"]
/// Rejection type used if you try and extract the URL params more than once. /// Rejection type used if axum's internal representation of path parameters is missing. This
pub struct MissingRouteParams; /// should never happen and is a bug in axum if it does.
pub struct MissingPathParams;
} }
define_rejection! { define_rejection! {
@ -64,33 +66,6 @@ define_rejection! {
pub struct InvalidFormContentType; pub struct InvalidFormContentType;
} }
/// Rejection type for [`Path`](super::Path) if the capture route
/// param didn't have the expected type.
#[derive(Debug)]
pub struct InvalidPathParam(pub(crate) String);
impl InvalidPathParam {
pub(super) fn new(err: impl Into<String>) -> Self {
InvalidPathParam(err.into())
}
}
impl IntoResponse for InvalidPathParam {
fn into_response(self) -> http::Response<BoxBody> {
let mut res = http::Response::new(boxed(Full::from(self.to_string())));
*res.status_mut() = http::StatusCode::BAD_REQUEST;
res
}
}
impl std::fmt::Display for InvalidPathParam {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invalid URL param. {}", self.0)
}
}
impl std::error::Error for InvalidPathParam {}
/// Rejection type for extractors that deserialize query strings if the input /// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type. /// couldn't be deserialized into the target type.
#[derive(Debug)] #[derive(Debug)]
@ -185,9 +160,9 @@ composite_rejection! {
/// ///
/// Contains one variant for each way the [`Path`](super::Path) extractor /// Contains one variant for each way the [`Path`](super::Path) extractor
/// can fail. /// can fail.
pub enum PathParamsRejection { pub enum PathRejection {
InvalidPathParam, FailedToDeserializePathParams,
MissingRouteParams, MissingPathParams,
} }
} }

View file

@ -1,5 +1,6 @@
use bytes::Bytes; use bytes::Bytes;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::fmt;
use std::ops::Deref; use std::ops::Deref;
/// A string like type backed by `Bytes` making it cheap to clone. /// A string like type backed by `Bytes` making it cheap to clone.
@ -30,6 +31,12 @@ impl ByteStr {
} }
} }
impl fmt::Display for ByteStr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedByteStr(ByteStr); pub(crate) struct PercentDecodedByteStr(ByteStr);

View file

@ -0,0 +1,13 @@
[package]
name = "example-customize-path-rejection"
version = "0.1.0"
edition = "2018"
publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1"
tracing-subscriber = { version="0.3", features = ["env-filter"] }

View file

@ -0,0 +1,144 @@
//! Run with
//!
//! ```not_rust
//! cargo run -p example-customize-path-rejection
//! ```
use axum::{
async_trait,
extract::{path::ErrorKind, rejection::PathRejection, FromRequest, RequestParts},
http::StatusCode,
response::IntoResponse,
routing::get,
Router,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::net::SocketAddr;
#[tokio::main]
async fn main() {
// Set the RUST_LOG, if it hasn't been explicitly defined
if std::env::var_os("RUST_LOG").is_none() {
std::env::set_var("RUST_LOG", "example_customize_path_rejection=debug")
}
tracing_subscriber::fmt::init();
// build our application with a route
let app = Router::new().route("/users/:user_id/teams/:team_id", get(handler));
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn handler(Path(params): Path<Params>) -> impl IntoResponse {
axum::Json(params)
}
#[derive(Debug, Deserialize, Serialize)]
struct Params {
user_id: u32,
team_id: u32,
}
// We define our own `Path` extractor that customizes the error from `axum::extract::Path`
struct Path<T>(T);
#[async_trait]
impl<B, T> FromRequest<B> for Path<T>
where
// these trait bounds are copied from `impl FromRequest for axum::extract::path::Path`
T: DeserializeOwned + Send,
B: Send,
{
type Rejection = (StatusCode, axum::Json<PathError>);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
match axum::extract::Path::<T>::from_request(req).await {
Ok(value) => Ok(Self(value.0)),
Err(rejection) => {
let (status, body) = match rejection {
PathRejection::FailedToDeserializePathParams(inner) => {
let mut status = StatusCode::BAD_REQUEST;
let kind = inner.into_kind();
let body = match &kind {
ErrorKind::WrongNumberOfParameters { .. } => PathError {
message: kind.to_string(),
location: None,
},
ErrorKind::ParseErrorAtKey { key, .. } => PathError {
message: kind.to_string(),
location: Some(key.clone()),
},
ErrorKind::ParseErrorAtIndex { index, .. } => PathError {
message: kind.to_string(),
location: Some(index.to_string()),
},
ErrorKind::ParseError { .. } => PathError {
message: kind.to_string(),
location: None,
},
ErrorKind::InvalidUtf8InPathParam { key } => PathError {
message: kind.to_string(),
location: Some(key.clone()),
},
ErrorKind::UnsupportedType { .. } => {
// this error is caused by the programmer using an unsupported type
// (such as nested maps) so respond with `500` instead
status = StatusCode::INTERNAL_SERVER_ERROR;
PathError {
message: kind.to_string(),
location: None,
}
}
ErrorKind::Message(msg) => PathError {
message: msg.clone(),
location: None,
},
_ => PathError {
message: format!("Unhandled deserialization error: {}", kind),
location: None,
},
};
(status, body)
}
PathRejection::MissingPathParams(error) => (
StatusCode::INTERNAL_SERVER_ERROR,
PathError {
message: error.to_string(),
location: None,
},
),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
PathError {
message: format!("Unhandled path rejection: {}", rejection),
location: None,
},
),
};
Err((status, axum::Json(body)))
}
}
}
}
#[derive(Serialize)]
struct PathError {
message: String,
location: Option<String>,
}