diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 3cb1939e..26ed1fae 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **added:** Implement `IntoResponse` for `&'static [u8; N]` and `[u8; N]` ([#1690]) +- **fixed:** Make `Path` support types uses `serde::Deserializer::deserialize_any` ([#1693]) [#1690]: https://github.com/tokio-rs/axum/pull/1690 +[#1693]: https://github.com/tokio-rs/axum/pull/1693 # 0.6.2 (9. January, 2023) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 9e565936..de003dd0 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -74,6 +74,7 @@ quickcheck_macros = "1.0" reqwest = { version = "0.11.11", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.21", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" tracing = "0.1" diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 0adf52de..bca576eb 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -56,7 +56,6 @@ impl<'de> PathDeserializer<'de> { impl<'de> Deserializer<'de> for PathDeserializer<'de> { type Error = PathDeserializationError; - unsupported_type!(deserialize_any); unsupported_type!(deserialize_bytes); unsupported_type!(deserialize_option); unsupported_type!(deserialize_identifier); @@ -79,6 +78,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { parse_single_value!(deserialize_byte_buf, visit_string, "String"); parse_single_value!(deserialize_char, visit_char, "char"); + fn deserialize_any(self, v: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(v) + } + fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, @@ -328,7 +334,6 @@ struct ValueDeserializer<'de> { impl<'de> Deserializer<'de> for ValueDeserializer<'de> { type Error = PathDeserializationError; - unsupported_type!(deserialize_any); unsupported_type!(deserialize_map); unsupported_type!(deserialize_identifier); @@ -349,6 +354,13 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { parse_value!(deserialize_byte_buf, visit_string, "String"); parse_value!(deserialize_char, visit_char, "char"); + fn deserialize_any(self, v: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(v) + } + fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de>, diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index a5a239af..5bcb4378 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -191,8 +191,11 @@ where } }; + // dbg!(¶ms); + T::deserialize(de::PathDeserializer::new(params)) .map_err(|err| { + // dbg!(&err); PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err)) }) .map(Path) @@ -215,7 +218,9 @@ impl PathDeserializationError { WrongNumberOfParameters { got: () } } + #[track_caller] pub(super) fn unsupported_type(name: &'static str) -> Self { + println!("{}", std::panic::Location::caller()); Self::new(ErrorKind::UnsupportedType { name }) } } @@ -432,6 +437,7 @@ mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use http::StatusCode; + use serde::Deserialize; use std::collections::HashMap; #[tokio::test] @@ -596,4 +602,118 @@ mod tests { let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::OK); } + + #[tokio::test] + async fn type_that_uses_deserialize_any() { + use time::Date; + + #[derive(Deserialize)] + struct Params { + a: Date, + b: Date, + c: Date, + } + + let app = Router::new() + .route( + "/single/:a", + get(|Path(a): Path| async move { format!("single: {a}") }), + ) + .route( + "/tuple/:a/:b/:c", + get(|Path((a, b, c)): Path<(Date, Date, Date)>| async move { + format!("tuple: {a} {b} {c}") + }), + ) + .route( + "/vec/:a/:b/:c", + get(|Path(vec): Path>| async move { + let [a, b, c]: [Date; 3] = vec.try_into().unwrap(); + format!("vec: {a} {b} {c}") + }), + ) + .route( + "/vec_pairs/:a/:b/:c", + get(|Path(vec): Path>| async move { + let [(_, a), (_, b), (_, c)]: [(String, Date); 3] = vec.try_into().unwrap(); + format!("vec_pairs: {a} {b} {c}") + }), + ) + .route( + "/map/:a/:b/:c", + get(|Path(mut map): Path>| async move { + let a = map.remove("a").unwrap(); + let b = map.remove("b").unwrap(); + let c = map.remove("c").unwrap(); + format!("map: {a} {b} {c}") + }), + ) + .route( + "/struct/:a/:b/:c", + get(|Path(params): Path| async move { + format!("struct: {} {} {}", params.a, params.b, params.c) + }), + ); + + let client = TestClient::new(app); + + let res = client.get("/single/2023-01-01").send().await; + assert_eq!(res.text().await, "single: 2023-01-01"); + + let res = client + .get("/tuple/2023-01-01/2023-01-02/2023-01-03") + .send() + .await; + assert_eq!(res.text().await, "tuple: 2023-01-01 2023-01-02 2023-01-03"); + + let res = client + .get("/vec/2023-01-01/2023-01-02/2023-01-03") + .send() + .await; + assert_eq!(res.text().await, "vec: 2023-01-01 2023-01-02 2023-01-03"); + + let res = client + .get("/vec_pairs/2023-01-01/2023-01-02/2023-01-03") + .send() + .await; + assert_eq!( + res.text().await, + "vec_pairs: 2023-01-01 2023-01-02 2023-01-03", + ); + + let res = client + .get("/map/2023-01-01/2023-01-02/2023-01-03") + .send() + .await; + assert_eq!(res.text().await, "map: 2023-01-01 2023-01-02 2023-01-03"); + + let res = client + .get("/struct/2023-01-01/2023-01-02/2023-01-03") + .send() + .await; + assert_eq!(res.text().await, "struct: 2023-01-01 2023-01-02 2023-01-03"); + } + + #[tokio::test] + async fn wrong_number_of_parameters_json() { + use serde_json::Value; + + let app = Router::new() + .route("/one/:a", get(|_: Path<(Value, Value)>| async {})) + .route("/two/:a/:b", get(|_: Path| async {})); + + let client = TestClient::new(app); + + let res = client.get("/one/1").send().await; + assert!(res + .text() + .await + .starts_with("Wrong number of path arguments for `Path`. Expected 2 but got 1")); + + let res = client.get("/two/1/2").send().await; + assert!(res + .text() + .await + .starts_with("Wrong number of path arguments for `Path`. Expected 1 but got 2")); + } }