From 56159b0d4e0db46b45d471bc291ab270c1e1dd77 Mon Sep 17 00:00:00 2001 From: future-highway <113635015+future-highway@users.noreply.github.com> Date: Fri, 29 Dec 2023 06:06:47 -0500 Subject: [PATCH] JsonDeserializer extractor for zero-copy deserialization (#2431) --- axum-extra/CHANGELOG.md | 2 + axum-extra/Cargo.toml | 2 + axum-extra/src/extract/json_deserializer.rs | 446 ++++++++++++++++++++ axum-extra/src/extract/mod.rs | 9 + axum-extra/src/lib.rs | 1 + axum/src/json.rs | 4 +- 6 files changed, 462 insertions(+), 2 deletions(-) create mode 100644 axum-extra/src/extract/json_deserializer.rs diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 7120f6bc..c2d0a1f6 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning]. # Unreleased - **change:** Update version of multer used internally for multipart ([#2433]) +- **added:** `JsonDeserializer` extractor ([#2431]) [#2433]: https://github.com/tokio-rs/axum/pull/2433 +[#2431]: https://github.com/tokio-rs/axum/pull/2431 # 0.9.0 (27. 
November, 2023) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 04929dad..d2fa8993 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -21,6 +21,7 @@ cookie-signed = ["cookie", "cookie?/signed"] cookie-key-expansion = ["cookie", "cookie?/key-expansion"] erased-json = ["dep:serde_json"] form = ["dep:serde_html_form"] +json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"] json-lines = [ "dep:serde_json", "dep:tokio-util", @@ -60,6 +61,7 @@ percent-encoding = { version = "2.1", optional = true } prost = { version = "0.12", optional = true } serde_html_form = { version = "0.2.0", optional = true } serde_json = { version = "1.0.71", optional = true } +serde_path_to_error = { version = "0.1.8", optional = true } tokio = { version = "1.19", optional = true } tokio-stream = { version = "0.1.9", optional = true } tokio-util = { version = "0.7", optional = true } diff --git a/axum-extra/src/extract/json_deserializer.rs b/axum-extra/src/extract/json_deserializer.rs new file mode 100644 index 00000000..0a307987 --- /dev/null +++ b/axum-extra/src/extract/json_deserializer.rs @@ -0,0 +1,446 @@ +use axum::async_trait; +use axum::extract::{FromRequest, Request}; +use axum_core::__composite_rejection as composite_rejection; +use axum_core::__define_rejection as define_rejection; +use axum_core::extract::rejection::BytesRejection; +use bytes::Bytes; +use http::{header, HeaderMap}; +use serde::Deserialize; +use std::marker::PhantomData; + +/// JSON Extractor for zero-copy deserialization. +/// +/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`]. +/// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called. +/// If the type implements [`serde::de::DeserializeOwned`], the [`Json`](axum::Json) extractor should +/// be preferred. 
+/// +/// The request will be rejected (and a [`JsonDeserializerRejection`] will be returned) if: +/// +/// - The request doesn't have a `Content-Type: application/json` (or similar) header. +/// - Buffering the request body fails. +/// +/// Additionally, a `JsonDeserializerRejection` will be returned when calling `deserialize` if: +/// +/// - The body doesn't contain syntactically valid JSON. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target +/// type. +/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`). +/// +/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the +/// input contains escaped characters. Use `Cow<'a, str>` or `Cow<'a, [u8]>`, with the +/// `#[serde(borrow)]` attribute, to allow serde to fall back to an owned type when encountering +/// escaped characters. +/// +/// ⚠️ Since parsing JSON requires consuming the request body, the `JsonDeserializer` extractor +/// must be *last* if there are multiple extractors in a handler. +/// See ["the order of extractors"][order-of-extractors] +/// +/// [order-of-extractors]: axum::extract#the-order-of-extractors +/// +/// See [`JsonDeserializerRejection`] for more details. 
+/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{ +/// routing::post, +/// Router, +/// response::{IntoResponse, Response} +/// }; +/// use axum_extra::extract::JsonDeserializer; +/// use serde::Deserialize; +/// use std::borrow::Cow; +/// use http::StatusCode; +/// +/// #[derive(Deserialize)] +/// struct Data<'a> { +/// #[serde(borrow)] +/// borrow_text: Cow<'a, str>, +/// #[serde(borrow)] +/// borrow_bytes: Cow<'a, [u8]>, +/// borrow_dangerous: &'a str, +/// not_borrowed: String, +/// } +/// +/// async fn upload(deserializer: JsonDeserializer<Data<'_>>) -> Response { +/// let data = match deserializer.deserialize() { +/// Ok(data) => data, +/// Err(e) => return e.into_response(), +/// }; +/// +/// // `data` is a `Data` with borrowed data from `deserializer`, +/// // which owns the request body (`Bytes`). +/// +/// StatusCode::OK.into_response() +/// } +/// +/// let app = Router::new().route("/upload", post(upload)); +/// # let _: Router = app; +/// ``` +#[derive(Debug, Clone, Default)] +#[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))] +pub struct JsonDeserializer<T> { + bytes: Bytes, + _marker: PhantomData<T>, +} + +#[async_trait] +impl<T, S> FromRequest<S> for JsonDeserializer<T> +where + T: Deserialize<'static>, + S: Send + Sync, +{ + type Rejection = JsonDeserializerRejection; + + async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { + if json_content_type(req.headers()) { + let bytes = Bytes::from_request(req, state).await?; + Ok(Self { + bytes, + _marker: PhantomData, + }) + } else { + Err(MissingJsonContentType.into()) + } + } +} + +impl<'de, 'a: 'de, T> JsonDeserializer<T> +where + T: Deserialize<'de>, +{ + /// Deserialize the request body into the target type. + /// See [`JsonDeserializer`] for more details. 
+ pub fn deserialize(&'a self) -> Result<T, JsonDeserializerRejection> { + let deserializer = &mut serde_json::Deserializer::from_slice(&self.bytes); + + let value = match serde_path_to_error::deserialize(deserializer) { + Ok(value) => value, + Err(err) => { + let rejection = match err.inner().classify() { + serde_json::error::Category::Data => JsonDataError::from_err(err).into(), + serde_json::error::Category::Syntax | serde_json::error::Category::Eof => { + JsonSyntaxError::from_err(err).into() + } + serde_json::error::Category::Io => { + if cfg!(debug_assertions) { + // we don't use `serde_json::from_reader` and instead always buffer + // bodies first, so we shouldn't encounter any IO errors + unreachable!() + } else { + JsonSyntaxError::from_err(err).into() + } + } + }; + return Err(rejection); + } + }; + + Ok(value) + } +} + +define_rejection! { + #[status = UNPROCESSABLE_ENTITY] + #[body = "Failed to deserialize the JSON body into the target type"] + #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))] + /// Rejection type for [`JsonDeserializer`]. + /// + /// This rejection is used if the request body is syntactically valid JSON but couldn't be + /// deserialized into the target type. + pub struct JsonDataError(Error); +} + +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to parse the request body as JSON"] + #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))] + /// Rejection type for [`JsonDeserializer`]. + /// + /// This rejection is used if the request body didn't contain syntactically valid JSON. + pub struct JsonSyntaxError(Error); +} + +define_rejection! { + #[status = UNSUPPORTED_MEDIA_TYPE] + #[body = "Expected request with `Content-Type: application/json`"] + #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))] + /// Rejection type for [`JsonDeserializer`] used if the `Content-Type` + /// header is missing or is not a JSON media type. + pub struct MissingJsonContentType; +} + +composite_rejection! 
{ + /// Rejection used for [`JsonDeserializer`]. + /// + /// Contains one variant for each way the [`JsonDeserializer`] extractor + /// can fail. + #[cfg_attr(docsrs, doc(cfg(feature = "json-deserializer")))] + pub enum JsonDeserializerRejection { + JsonDataError, + JsonSyntaxError, + MissingJsonContentType, + BytesRejection, + } +} + +fn json_content_type(headers: &HeaderMap) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + content_type + } else { + return false; + }; + + let content_type = if let Ok(content_type) = content_type.to_str() { + content_type + } else { + return false; + }; + + let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() { + mime + } else { + return false; + }; + + let is_json_content_type = mime.type_() == "application" + && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + + is_json_content_type +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::*; + use axum::{ + response::{IntoResponse, Response}, + routing::post, + Router, + }; + use http::StatusCode; + use serde::Deserialize; + use serde_json::{json, Value}; + use std::borrow::Cow; + + #[tokio::test] + async fn deserialize_body() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + #[serde(borrow)] + foo: Cow<'a, str>, + } + + async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response { + match deserializer.deserialize() { + Ok(input) => { + assert!(matches!(input.foo, Cow::Borrowed(_))); + input.foo.into_owned().into_response() + } + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + let res = client.post("/").json(&json!({ "foo": "bar" })).send().await; + let body = res.text().await; + + assert_eq!(body, "bar"); + } + + #[tokio::test] + async fn deserialize_body_escaped_to_cow() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + #[serde(borrow)] + foo: Cow<'a, str>, + 
} + + async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response { + match deserializer.deserialize() { + Ok(Input { foo }) => { + let Cow::Owned(foo) = foo else { + panic!("Deserializer is expected to fallback to Cow::Owned when encountering escaped characters") + }; + + foo.into_response() + } + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + + // The escaped characters prevent serde_json from borrowing. + let res = client + .post("/") + .json(&json!({ "foo": "\"bar\"" })) + .send() + .await; + + let body = res.text().await; + + assert_eq!(body, r#""bar""#); + } + + #[tokio::test] + async fn deserialize_body_escaped_to_str() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + // Explicit `#[serde(borrow)]` attribute is not required for `&str` or &[u8]. + // See: https://serde.rs/lifetimes.html#borrowing-data-in-a-derived-impl + foo: &'a str, + } + + async fn handler(deserializer: JsonDeserializer<Input<'_>>) -> Response { + match deserializer.deserialize() { + Ok(Input { foo }) => foo.to_owned().into_response(), + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + + let res = client + .post("/") + .json(&json!({ "foo": "good" })) + .send() + .await; + let body = res.text().await; + assert_eq!(body, "good"); + + let res = client + .post("/") + .json(&json!({ "foo": "\"bad\"" })) + .send() + .await; + assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + let body_text = res.text().await; + assert_eq!( + body_text, + "Failed to deserialize the JSON body into the target type: foo: invalid type: string \"\\\"bad\\\"\", expected a borrowed string at line 1 column 16" + ); + } + + #[tokio::test] + async fn consume_body_to_json_requires_json_content_type() { + #[derive(Debug, Deserialize)] + struct Input<'a> { + #[allow(dead_code)] + foo: Cow<'a, str>, + } + + async fn 
handler(_deserializer: JsonDeserializer<Input<'_>>) -> Response { + panic!("This handler should not be called") + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await; + + let status = res.status(); + + assert_eq!(status, StatusCode::UNSUPPORTED_MEDIA_TYPE); + } + + #[tokio::test] + async fn json_content_types() { + async fn valid_json_content_type(content_type: &str) -> bool { + println!("testing {content_type:?}"); + + async fn handler(_deserializer: JsonDeserializer<Value>) -> Response { + StatusCode::OK.into_response() + } + + let app = Router::new().route("/", post(handler)); + + let res = TestClient::new(app) + .post("/") + .header("content-type", content_type) + .body("{}") + .send() + .await; + + res.status() == StatusCode::OK + } + + assert!(valid_json_content_type("application/json").await); + assert!(valid_json_content_type("application/json; charset=utf-8").await); + assert!(valid_json_content_type("application/json;charset=utf-8").await); + assert!(valid_json_content_type("application/cloudevents+json").await); + assert!(!valid_json_content_type("text/json").await); + } + + #[tokio::test] + async fn invalid_json_syntax() { + async fn handler(deserializer: JsonDeserializer<Value>) -> Response { + match deserializer.deserialize() { + Ok(_) => panic!("Should have matched `Err`"), + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + let res = client + .post("/") + .body("{") + .header("content-type", "application/json") + .send() + .await; + + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + } + + #[derive(Deserialize)] + struct Foo { + #[allow(dead_code)] + a: i32, + #[allow(dead_code)] + b: Vec<Bar>, + } + + #[derive(Deserialize)] + struct Bar { + #[allow(dead_code)] + x: i32, + #[allow(dead_code)] + y: i32, + } + + #[tokio::test] + async fn 
invalid_json_data() { + async fn handler(deserializer: JsonDeserializer<Foo>) -> Response { + match deserializer.deserialize() { + Ok(_) => panic!("Should have matched `Err`"), + Err(e) => e.into_response(), + } + } + + let app = Router::new().route("/", post(handler)); + + let client = TestClient::new(app); + let res = client + .post("/") + .body("{\"a\": 1, \"b\": [{\"x\": 2}]}") + .header("content-type", "application/json") + .send() + .await; + + assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); + let body_text = res.text().await; + assert_eq!( + body_text, + "Failed to deserialize the JSON body into the target type: b[0]: missing field `y` at line 1 column 23" + ); + } +} diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 8435fc84..1f9974de 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -10,6 +10,9 @@ mod form; #[cfg(feature = "cookie")] pub mod cookie; +#[cfg(feature = "json-deserializer")] +mod json_deserializer; + #[cfg(feature = "query")] mod query; @@ -36,6 +39,12 @@ pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejecti #[cfg(feature = "multipart")] pub use self::multipart::Multipart; +#[cfg(feature = "json-deserializer")] +pub use self::json_deserializer::{ + JsonDataError, JsonDeserializer, JsonDeserializerRejection, JsonSyntaxError, + MissingJsonContentType, +}; + #[cfg(feature = "json-lines")] #[doc(no_inline)] pub use crate::json_lines::JsonLines; diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 12aa2801..eb93b0a3 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -16,6 +16,7 @@ //! `cookie-key-expansion` | Enables the `Key::derive_from` method | No //! `erased-json` | Enables the `ErasedJson` response | No //! `form` | Enables the `Form` extractor | No +//! `json-deserializer` | Enables the `JsonDeserializer` extractor | No //! `json-lines` | Enables the `JsonLines` extractor and response | No //! 
`multipart` | Enables the `Multipart` extractor | No //! `protobuf` | Enables the `Protobuf` extractor and response | No diff --git a/axum/src/json.rs b/axum/src/json.rs index ebff242d..e96be5b8 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -12,12 +12,12 @@ use serde::{de::DeserializeOwned, Serialize}; /// JSON Extractor / Response. /// /// When used as an extractor, it can deserialize request bodies into some type that -/// implements [`serde::Deserialize`]. The request will be rejected (and a [`JsonRejection`] will +/// implements [`serde::de::DeserializeOwned`]. The request will be rejected (and a [`JsonRejection`] will /// be returned) if: /// /// - The request doesn't have a `Content-Type: application/json` (or similar) header. /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON but it couldn't be deserialized into the target +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target /// type. /// - Buffering the request body fails. ///