1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Query/Form: Use serde_path_to_error to report fields that failed to parse ()

This commit is contained in:
Tobias Bieniek 2024-12-20 11:42:56 +01:00 committed by GitHub
parent ab8d0088d0
commit 9cd5cc4fc1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 111 additions and 16 deletions

View file

@ -10,8 +10,10 @@ and this project adheres to [Semantic Versioning].
- **breaking:** `axum::extract::ws::Message` now uses `Bytes` in place of `Vec<u8>`,
and a new `Utf8Bytes` type in place of `String`, for its variants ([#3078])
- **changed:** Upgraded `tokio-tungstenite` to 0.26 ([#3078])
- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081])
[#3078]: https://github.com/tokio-rs/axum/pull/3078
[#3081]: https://github.com/tokio-rs/axum/pull/3081
# 0.10.0

View file

@ -23,7 +23,7 @@ cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
cookie-key-expansion = ["cookie", "cookie?/key-expansion"]
erased-json = ["dep:serde_json", "dep:typed-json"]
form = ["dep:serde_html_form"]
form = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"]
json-lines = [
"dep:serde_json",
@ -36,7 +36,7 @@ json-lines = [
multipart = ["dep:multer", "dep:fastrand"]
protobuf = ["dep:prost"]
scheme = []
query = ["dep:serde_html_form"]
query = ["dep:form_urlencoded", "dep:serde_html_form", "dep:serde_path_to_error"]
tracing = ["axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]

View file

@ -56,7 +56,9 @@ where
.await
.map_err(FormRejection::RawFormRejection)?;
serde_html_form::from_bytes::<T>(&bytes)
let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(&bytes));
serde_path_to_error::deserialize::<_, T>(deserializer)
.map(Self)
.map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err)))
}
@ -115,8 +117,10 @@ impl std::error::Error for FormRejection {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{routing::post, Router};
use axum::routing::{on, post, MethodFilter};
use axum::Router;
use http::header::CONTENT_TYPE;
use mime::APPLICATION_WWW_FORM_URLENCODED;
use serde::Deserialize;
#[tokio::test]
@ -143,4 +147,41 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}
#[tokio::test]
async fn deserialize_error_status_codes() {
#[allow(dead_code)]
#[derive(Deserialize)]
struct Payload {
a: i32,
}
let app = Router::new().route(
"/",
on(
MethodFilter::GET.or(MethodFilter::POST),
|_: Form<Payload>| async {},
),
);
let client = TestClient::new(app);
let res = client.get("/?a=false").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);
let res = client
.post("/")
.header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body("a=false")
.await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);
}
}

View file

@ -103,7 +103,9 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let query = parts.uri.query().unwrap_or_default();
let value = serde_html_form::from_str(query)
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Query(value))
}
@ -121,7 +123,9 @@ where
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query)
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Some(Self(value)))
} else {
@ -230,7 +234,9 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query).map_err(|err| {
let deserializer =
serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer).map_err(|err| {
OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err))
})?;
Ok(OptionalQuery(Some(value)))
@ -302,7 +308,8 @@ impl std::error::Error for OptionalQueryRejection {
mod tests {
use super::*;
use crate::test_helpers::*;
use axum::{routing::post, Router};
use axum::routing::{get, post};
use axum::Router;
use http::header::CONTENT_TYPE;
use serde::Deserialize;
@ -331,6 +338,27 @@ mod tests {
assert_eq!(res.text().await, "one,two");
}
#[tokio::test]
async fn correct_rejection_status_code() {
#[derive(Deserialize)]
#[allow(dead_code)]
struct Params {
n: i32,
}
async fn handler(_: Query<Params>) {}
let app = Router::new().route("/", get(handler));
let client = TestClient::new(app);
let res = client.get("/?n=hi").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize query string: n: invalid digit found in string"
);
}
#[tokio::test]
async fn optional_query_supports_multiple_values() {
#[derive(Deserialize)]

View file

@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **changed:** Query/Form: Use `serde_path_to_error` to report fields that failed to parse ([#3081])
[#3081]: https://github.com/tokio-rs/axum/pull/3081
# 0.8.0
## rc.1

View file

@ -23,7 +23,7 @@ default = [
"tower-log",
"tracing",
]
form = ["dep:serde_urlencoded"]
form = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"]
http1 = ["dep:hyper", "hyper?/http1", "hyper-util?/http1"]
http2 = ["dep:hyper", "hyper?/http2", "hyper-util?/http2"]
json = ["dep:serde_json", "dep:serde_path_to_error"]
@ -31,7 +31,7 @@ macros = ["dep:axum-macros"]
matched-path = []
multipart = ["dep:multer"]
original-uri = []
query = ["dep:serde_urlencoded"]
query = ["dep:form_urlencoded", "dep:serde_urlencoded", "dep:serde_path_to_error"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"]
tower-log = ["tower/log"]
tracing = ["dep:tracing", "axum-core/tracing"]
@ -68,6 +68,7 @@ tower-service = "0.3"
# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0-rc.1", optional = true }
base64 = { version = "0.22.1", optional = true }
form_urlencoded = { version = "1.1.0", optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
multer = { version = "3.0.0", optional = true }

View file

@ -87,7 +87,9 @@ where
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_urlencoded::from_str(query)
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let value = serde_path_to_error::deserialize(deserializer)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Some(Self(value)))
} else {
@ -121,8 +123,10 @@ where
/// ```
pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
let query = value.query().unwrap_or_default();
let params =
serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
let params = serde_path_to_error::deserialize(deserializer)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Query(params))
}
}
@ -201,6 +205,10 @@ mod tests {
let res = client.get("/?n=hi").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize query string: n: invalid digit found in string"
);
}
#[test]

View file

@ -84,14 +84,17 @@ where
match req.extract().await {
Ok(RawForm(bytes)) => {
let value =
serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection {
let deserializer =
serde_urlencoded::Deserializer::new(form_urlencoded::parse(&bytes));
let value = serde_path_to_error::deserialize(deserializer).map_err(
|err| -> FormRejection {
if is_get_or_head {
FailedToDeserializeForm::from_err(err).into()
} else {
FailedToDeserializeFormBody::from_err(err).into()
}
})?;
},
)?;
Ok(Form(value))
}
Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
@ -252,6 +255,10 @@ mod tests {
let res = client.get("/?a=false").await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
assert_eq!(
res.text().await,
"Failed to deserialize form: a: invalid digit found in string"
);
let res = client
.post("/")
@ -259,5 +266,9 @@ mod tests {
.body("a=false")
.await;
assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
assert_eq!(
res.text().await,
"Failed to deserialize form body: a: invalid digit found in string"
);
}
}