Support deserialize_any for Path (#1693)

This commit is contained in:
David Pedersen 2023-01-13 13:27:38 +01:00 committed by GitHub
parent 25a46fbe79
commit 607a20dfac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 2 deletions

View file

@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **added:** Implement `IntoResponse` for `&'static [u8; N]` and `[u8; N]` ([#1690]) - **added:** Implement `IntoResponse` for `&'static [u8; N]` and `[u8; N]` ([#1690])
- **fixed:** Make `Path` support types uses `serde::Deserializer::deserialize_any` ([#1693])
[#1690]: https://github.com/tokio-rs/axum/pull/1690 [#1690]: https://github.com/tokio-rs/axum/pull/1690
[#1693]: https://github.com/tokio-rs/axum/pull/1693
# 0.6.2 (9. January, 2023) # 0.6.2 (9. January, 2023)

View file

@ -74,6 +74,7 @@ quickcheck_macros = "1.0"
reqwest = { version = "0.11.11", default-features = false, features = ["json", "stream", "multipart"] } reqwest = { version = "0.11.11", default-features = false, features = ["json", "stream", "multipart"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.21", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio = { package = "tokio", version = "1.21", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1" tokio-stream = "0.1"
tracing = "0.1" tracing = "0.1"

View file

@ -56,7 +56,6 @@ impl<'de> PathDeserializer<'de> {
impl<'de> Deserializer<'de> for PathDeserializer<'de> { impl<'de> Deserializer<'de> for PathDeserializer<'de> {
type Error = PathDeserializationError; type Error = PathDeserializationError;
unsupported_type!(deserialize_any);
unsupported_type!(deserialize_bytes); unsupported_type!(deserialize_bytes);
unsupported_type!(deserialize_option); unsupported_type!(deserialize_option);
unsupported_type!(deserialize_identifier); unsupported_type!(deserialize_identifier);
@ -79,6 +78,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
parse_single_value!(deserialize_byte_buf, visit_string, "String"); parse_single_value!(deserialize_byte_buf, visit_string, "String");
parse_single_value!(deserialize_char, visit_char, "char"); parse_single_value!(deserialize_char, visit_char, "char");
fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_str(v)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,
@ -328,7 +334,6 @@ struct ValueDeserializer<'de> {
impl<'de> Deserializer<'de> for ValueDeserializer<'de> { impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
type Error = PathDeserializationError; type Error = PathDeserializationError;
unsupported_type!(deserialize_any);
unsupported_type!(deserialize_map); unsupported_type!(deserialize_map);
unsupported_type!(deserialize_identifier); unsupported_type!(deserialize_identifier);
@ -349,6 +354,13 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
parse_value!(deserialize_byte_buf, visit_string, "String"); parse_value!(deserialize_byte_buf, visit_string, "String");
parse_value!(deserialize_char, visit_char, "char"); parse_value!(deserialize_char, visit_char, "char");
fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.deserialize_str(v)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where where
V: Visitor<'de>, V: Visitor<'de>,

View file

@ -191,8 +191,11 @@ where
} }
}; };
// dbg!(&params);
T::deserialize(de::PathDeserializer::new(params)) T::deserialize(de::PathDeserializer::new(params))
.map_err(|err| { .map_err(|err| {
// dbg!(&err);
PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err)) PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err))
}) })
.map(Path) .map(Path)
@ -215,7 +218,9 @@ impl PathDeserializationError {
WrongNumberOfParameters { got: () } WrongNumberOfParameters { got: () }
} }
#[track_caller]
pub(super) fn unsupported_type(name: &'static str) -> Self { pub(super) fn unsupported_type(name: &'static str) -> Self {
println!("{}", std::panic::Location::caller());
Self::new(ErrorKind::UnsupportedType { name }) Self::new(ErrorKind::UnsupportedType { name })
} }
} }
@ -432,6 +437,7 @@ mod tests {
use super::*; use super::*;
use crate::{routing::get, test_helpers::*, Router}; use crate::{routing::get, test_helpers::*, Router};
use http::StatusCode; use http::StatusCode;
use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
#[tokio::test] #[tokio::test]
@ -596,4 +602,118 @@ mod tests {
let res = client.get("/foo/bar").send().await; let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.status(), StatusCode::OK);
} }
#[tokio::test]
async fn type_that_uses_deserialize_any() {
use time::Date;
#[derive(Deserialize)]
struct Params {
a: Date,
b: Date,
c: Date,
}
let app = Router::new()
.route(
"/single/:a",
get(|Path(a): Path<Date>| async move { format!("single: {a}") }),
)
.route(
"/tuple/:a/:b/:c",
get(|Path((a, b, c)): Path<(Date, Date, Date)>| async move {
format!("tuple: {a} {b} {c}")
}),
)
.route(
"/vec/:a/:b/:c",
get(|Path(vec): Path<Vec<Date>>| async move {
let [a, b, c]: [Date; 3] = vec.try_into().unwrap();
format!("vec: {a} {b} {c}")
}),
)
.route(
"/vec_pairs/:a/:b/:c",
get(|Path(vec): Path<Vec<(String, Date)>>| async move {
let [(_, a), (_, b), (_, c)]: [(String, Date); 3] = vec.try_into().unwrap();
format!("vec_pairs: {a} {b} {c}")
}),
)
.route(
"/map/:a/:b/:c",
get(|Path(mut map): Path<HashMap<String, Date>>| async move {
let a = map.remove("a").unwrap();
let b = map.remove("b").unwrap();
let c = map.remove("c").unwrap();
format!("map: {a} {b} {c}")
}),
)
.route(
"/struct/:a/:b/:c",
get(|Path(params): Path<Params>| async move {
format!("struct: {} {} {}", params.a, params.b, params.c)
}),
);
let client = TestClient::new(app);
let res = client.get("/single/2023-01-01").send().await;
assert_eq!(res.text().await, "single: 2023-01-01");
let res = client
.get("/tuple/2023-01-01/2023-01-02/2023-01-03")
.send()
.await;
assert_eq!(res.text().await, "tuple: 2023-01-01 2023-01-02 2023-01-03");
let res = client
.get("/vec/2023-01-01/2023-01-02/2023-01-03")
.send()
.await;
assert_eq!(res.text().await, "vec: 2023-01-01 2023-01-02 2023-01-03");
let res = client
.get("/vec_pairs/2023-01-01/2023-01-02/2023-01-03")
.send()
.await;
assert_eq!(
res.text().await,
"vec_pairs: 2023-01-01 2023-01-02 2023-01-03",
);
let res = client
.get("/map/2023-01-01/2023-01-02/2023-01-03")
.send()
.await;
assert_eq!(res.text().await, "map: 2023-01-01 2023-01-02 2023-01-03");
let res = client
.get("/struct/2023-01-01/2023-01-02/2023-01-03")
.send()
.await;
assert_eq!(res.text().await, "struct: 2023-01-01 2023-01-02 2023-01-03");
}
#[tokio::test]
async fn wrong_number_of_parameters_json() {
use serde_json::Value;
let app = Router::new()
.route("/one/:a", get(|_: Path<(Value, Value)>| async {}))
.route("/two/:a/:b", get(|_: Path<Value>| async {}));
let client = TestClient::new(app);
let res = client.get("/one/1").send().await;
assert!(res
.text()
.await
.starts_with("Wrong number of path arguments for `Path`. Expected 2 but got 1"));
let res = client.get("/two/1/2").send().await;
assert!(res
.text()
.await
.starts_with("Wrong number of path arguments for `Path`. Expected 1 but got 2"));
}
} }