From 8cc052f38bd295aceaa934fa60c1bffd11de2004 Mon Sep 17 00:00:00 2001 From: Thomas Scholtes Date: Tue, 3 May 2022 20:44:58 +0200 Subject: [PATCH] Make `Path` extractor work with `Deserialize` impls using `&str` (#990) * `Path` extractor works with `Deserialize` impls using `&str` Before this change the extractor `Path` would fail if the `Deserialize` implementation of `Test` was calling `Deserializer::deserialize_str()`. Now we use `Visitor::visit_borrowed_str()` instead of `Visitor::visit_str()` which is also recommended in the guide to implement a deserializer [1]. [1]: https://serde.rs/impl-deserializer.html * fixup! `Path` extractor works with `Deserialize` impls using `&str` * add test for percent decoding Co-authored-by: David Pedersen --- axum/src/extract/path/de.rs | 4 +++- axum/src/extract/path/mod.rs | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 62236c6e..42403109 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -88,7 +88,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { .got(self.url_params.len()) .expected(1)); } - visitor.visit_str(&self.url_params[0].1) + visitor.visit_borrowed_str(&self.url_params[0].1) } fn deserialize_unit(self, visitor: V) -> Result @@ -608,6 +608,8 @@ mod tests { check_single_value!(f64, "123", 123.0); check_single_value!(String, "abc", "abc"); check_single_value!(String, "one%20two", "one two"); + check_single_value!(&str, "abc", "abc"); + check_single_value!(&str, "one%20two", "one two"); check_single_value!(char, "a", 'a'); let url_params = create_url_params(vec![("a", "B")]); diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 39fd3bb1..0dc0167b 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -514,4 +514,29 @@ mod tests { "No paths parameters found for matched route. Are you also extracting `Request<_>`?" ); } + + #[tokio::test] + async fn str_reference_deserialize() { + struct Param(String); + impl<'de> serde::Deserialize<'de> for Param { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = <&str as serde::Deserialize>::deserialize(deserializer)?; + Ok(Param(s.to_owned())) + } + } + + let app = Router::new().route("/:key", get(|param: Path| async move { param.0 .0 })); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "foo"); + + // percent decoding should also work + let res = client.get("/foo%20bar").send().await; + assert_eq!(res.text().await, "foo bar"); + } }