From 6c10f41d944c56b0d7e92b173bc84ba18f6b74c2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 27 May 2022 14:11:33 +0200 Subject: [PATCH] Support `Path>` (#1059) * Support `Path>` * changelog --- axum/CHANGELOG.md | 5 ++ axum/src/extract/path/de.rs | 169 ++++++++++++++++++++++++++++------- axum/src/extract/path/mod.rs | 21 +++++ axum/src/util.rs | 4 + 4 files changed, 165 insertions(+), 34 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index c5f74759..65fd881c 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **added:** Implement `Default` for `Extension` ([#1043]) +- **fixed:** Support deserializing `Vec<(String, String)>` in `extract::Path<_>` to get vector of + key/value pairs ([#1059]) + +[#1043]: https://github.com/tokio-rs/axum/pull/1043 +[#1059]: https://github.com/tokio-rs/axum/pull/1059 # 0.5.6 (15. May, 2022) diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 42403109..0adf52de 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -7,7 +7,7 @@ use serde::{ use std::{any::type_name, sync::Arc}; macro_rules! unsupported_type { - ($trait_fn:ident, $name:literal) => { + ($trait_fn:ident) => { fn $trait_fn(self, _: V) -> Result where V: Visitor<'de>, @@ -56,11 +56,11 @@ impl<'de> PathDeserializer<'de> { impl<'de> Deserializer<'de> for PathDeserializer<'de> { type Error = PathDeserializationError; - unsupported_type!(deserialize_any, "any"); - unsupported_type!(deserialize_bytes, "bytes"); - unsupported_type!(deserialize_option, "Option"); - unsupported_type!(deserialize_identifier, "identifier"); - unsupported_type!(deserialize_ignored_any, "ignored_any"); + unsupported_type!(deserialize_any); + unsupported_type!(deserialize_bytes); + unsupported_type!(deserialize_option); + unsupported_type!(deserialize_identifier); + unsupported_type!(deserialize_ignored_any); parse_single_value!(deserialize_bool, visit_bool, "bool"); parse_single_value!(deserialize_i8, visit_i8, "i8"); @@ -204,7 +204,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { } visitor.visit_enum(EnumDeserializer { - value: &self.url_params[0].1, + value: self.url_params[0].1.clone().into_inner(), }) } } @@ -212,7 +212,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { struct MapDeserializer<'de> { params: &'de [(Arc, PercentDecodedStr)], key: Option, - value: Option<&'de str>, + value: Option<&'de PercentDecodedStr>, } impl<'de> MapAccess<'de> for MapDeserializer<'de> { @@ -227,7 +227,10 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { self.value = Some(value); self.params = tail; self.key = Some(KeyOrIdx::Key(key.clone())); - seed.deserialize(KeyDeserializer { key }).map(Some) + seed.deserialize(KeyDeserializer { + key: Arc::clone(key), + }) + .map(Some) } None => Ok(None), } @@ -247,8 +250,8 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { } } -struct KeyDeserializer<'de> { - key: &'de str, +struct KeyDeserializer { + key: Arc, } macro_rules! parse_key { @@ -257,12 +260,12 @@ macro_rules! parse_key { where V: Visitor<'de>, { - visitor.visit_str(self.key) + visitor.visit_str(&self.key) } }; } -impl<'de> Deserializer<'de> for KeyDeserializer<'de> { +impl<'de> Deserializer<'de> for KeyDeserializer { type Error = PathDeserializationError; parse_key!(deserialize_identifier); @@ -294,19 +297,19 @@ macro_rules! parse_value { let kind = match key { KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { key: key.to_string(), - value: self.value.to_owned(), + value: self.value.as_str().to_owned(), expected_type: $ty, }, - KeyOrIdx::Idx(index) => ErrorKind::ParseErrorAtIndex { + KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex { index, - value: self.value.to_owned(), + value: self.value.as_str().to_owned(), expected_type: $ty, }, }; PathDeserializationError::new(kind) } else { PathDeserializationError::new(ErrorKind::ParseError { - value: self.value.to_owned(), + value: self.value.as_str().to_owned(), expected_type: $ty, }) } @@ -316,18 +319,18 @@ macro_rules! parse_value { }; } +#[derive(Debug)] struct ValueDeserializer<'de> { key: Option, - value: &'de str, + value: &'de PercentDecodedStr, } impl<'de> Deserializer<'de> for ValueDeserializer<'de> { type Error = PathDeserializationError; - unsupported_type!(deserialize_any, "any"); - unsupported_type!(deserialize_seq, "seq"); - unsupported_type!(deserialize_map, "map"); - unsupported_type!(deserialize_identifier, "identifier"); + unsupported_type!(deserialize_any); + unsupported_type!(deserialize_map); + unsupported_type!(deserialize_identifier); parse_value!(deserialize_bool, visit_bool, "bool"); parse_value!(deserialize_i8, visit_i8, "i8"); @@ -396,7 +399,57 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { visitor.visit_newtype_struct(self) } - fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + struct PairDeserializer<'de> { + key: Option, + value: Option<&'de PercentDecodedStr>, + } + + impl<'de> SeqAccess<'de> for PairDeserializer<'de> { + type Error = PathDeserializationError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.key.take() { + Some(KeyOrIdx::Idx { idx: _, key }) => { + return seed.deserialize(KeyDeserializer { key }).map(Some); + } + // `KeyOrIdx::Key` is only used when deserializing maps so `deserialize_seq` + // wouldn't be called for that + Some(KeyOrIdx::Key(_)) => unreachable!(), + None => {} + }; + + self.value + .take() + .map(|value| seed.deserialize(ValueDeserializer { key: None, value })) + .transpose() + } + } + + if len == 2 { + match self.key { + Some(key) => visitor.visit_seq(PairDeserializer { + key: Some(key), + value: Some(self.value), + }), + // `self.key` is only `None` when deserializing maps so `deserialize_seq` + // wouldn't be called for that + None => unreachable!(), + } + } else { + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) + } + } + + fn deserialize_seq(self, _visitor: V) -> Result where V: Visitor<'de>, { @@ -442,7 +495,9 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - visitor.visit_enum(EnumDeserializer { value: self.value }) + visitor.visit_enum(EnumDeserializer { + value: self.value.clone().into_inner(), + }) } fn deserialize_ignored_any(self, visitor: V) -> Result @@ -453,11 +508,11 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { } } -struct EnumDeserializer<'de> { - value: &'de str, +struct EnumDeserializer { + value: Arc, } -impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { +impl<'de> EnumAccess<'de> for EnumDeserializer { type Error = PathDeserializationError; type Variant = UnitVariant; @@ -526,12 +581,15 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { T: DeserializeSeed<'de>, { match self.params.split_first() { - Some(((_, value), tail)) => { + Some(((key, value), tail)) => { self.params = tail; let idx = self.idx; self.idx += 1; Ok(Some(seed.deserialize(ValueDeserializer { - key: Some(KeyOrIdx::Idx(idx)), + key: Some(KeyOrIdx::Idx { + idx, + key: key.clone(), + }), value, })?)) } @@ -540,10 +598,10 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { } } -#[derive(Clone)] +#[derive(Debug, Clone)] enum KeyOrIdx { Key(Arc), - Idx(usize), + Idx { idx: usize, key: Arc }, } #[cfg(test)] @@ -659,6 +717,27 @@ mod tests { ); } + #[test] + fn test_parse_seq_tuple_string_string() { + let url_params = create_url_params(vec![("a", "foo"), ("b", "bar")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + vec![ + ("a".to_owned(), "foo".to_owned()), + ("b".to_owned(), "bar".to_owned()) + ] + ); + } + + #[test] + fn test_parse_seq_tuple_string_parse() { + let url_params = create_url_params(vec![("a", "1"), ("b", "2")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + vec![("a".to_owned(), 1), ("b".to_owned(), 2)] + ); + } + #[test] fn test_parse_struct() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); @@ -816,11 +895,33 @@ mod tests { } #[test] - fn test_unsupported_type_error_tuple() { + fn test_parse_seq_tuple_unsupported_key_type() { test_parse_error!( vec![("a", "false")], - Vec<(u32, u32)>, - ErrorKind::UnsupportedType { name: "(u32, u32)" } + Vec<(u32, String)>, + ErrorKind::Message("Unexpected key type".to_owned()) + ); + } + + #[test] + fn test_parse_seq_wrong_tuple_length() { + test_parse_error!( + vec![("a", "false")], + Vec<(String, String, String)>, + ErrorKind::UnsupportedType { + name: "(alloc::string::String, alloc::string::String, alloc::string::String)", + } + ); + } + + #[test] + fn test_parse_seq_seq() { + test_parse_error!( + vec![("a", "false")], + Vec>, + ErrorKind::UnsupportedType { + name: "alloc::vec::Vec", + } ); } } diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 540ff8a7..f559ae1f 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -571,4 +571,25 @@ mod tests { Note that multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`", ); } + + #[tokio::test] + async fn deserialize_into_vec_of_tuples() { + let app = Router::new().route( + "/:a/:b", + get(|Path(params): Path>| async move { + assert_eq!( + params, + vec![ + ("a".to_owned(), "foo".to_owned()), + ("b".to_owned(), "bar".to_owned()) + ] + ); + }), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } } diff --git a/axum/src/util.rs b/axum/src/util.rs index b5805e07..b494159c 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -18,6 +18,10 @@ impl PercentDecodedStr { pub(crate) fn as_str(&self) -> &str { &*self.0 } + + pub(crate) fn into_inner(self) -> Arc { + self.0 + } } impl Deref for PercentDecodedStr {