diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 9f229bf7..a383f477 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **added:** Add `JsonLines` for streaming newline delimited JSON ([#1093]) - **change:** axum's MSRV is now 1.56 ([#1098]) +[#1093]: https://github.com/tokio-rs/axum/pull/1093 [#1098]: https://github.com/tokio-rs/axum/pull/1098 # 0.3.4 (08. June, 2022) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 508f6f27..2917c8e9 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -19,6 +19,7 @@ cookie-private = ["cookie", "cookie-lib/private"] cookie-signed = ["cookie", "cookie-lib/signed"] erased-json = ["serde_json", "serde"] form = ["serde", "serde_html_form"] +json-lines = ["serde_json", "serde", "tokio-util/io", "tokio-stream/io-util"] query = ["serde", "serde_html_form"] spa = ["tower-http/fs"] typed-routing = ["axum-macros", "serde", "percent-encoding"] @@ -26,6 +27,7 @@ typed-routing = ["axum-macros", "serde", "percent-encoding"] [dependencies] axum = { path = "../axum", version = "0.5", default-features = false } bytes = "1.1.0" +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "0.2" mime = "0.3" pin-project-lite = "0.2" @@ -42,10 +44,12 @@ percent-encoding = { version = "2.1", optional = true } serde = { version = "1.0", optional = true } serde_html_form = { version = "0.1", optional = true } serde_json = { version = "1.0.71", optional = true } +tokio-stream = { version = "0.1", optional = true } tokio-util = { version = "0.7", optional = true } [dev-dependencies] axum = { path = "../axum", version = "0.5", features = ["headers"] } +futures = "0.3" hyper = "0.14" reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } diff --git a/axum-extra/src/extract/mod.rs 
b/axum-extra/src/extract/mod.rs index 9abed09f..22e19559 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -27,3 +27,7 @@ pub use self::form::Form; #[cfg(feature = "query")] pub use self::query::Query; + +#[cfg(feature = "json-lines")] +#[doc(no_inline)] +pub use crate::json_lines::JsonLines; diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs new file mode 100644 index 00000000..1bc8d1a0 --- /dev/null +++ b/axum-extra/src/json_lines.rs @@ -0,0 +1,286 @@ +//! Newline delimited JSON extractor and response. + +use axum::{ + async_trait, + body::{HttpBody, StreamBody}, + extract::{rejection::BodyAlreadyExtracted, FromRequest, RequestParts}, + response::{IntoResponse, Response}, + BoxError, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures_util::stream::{BoxStream, Stream, TryStream, TryStreamExt}; +use pin_project_lite::pin_project; +use serde::{de::DeserializeOwned, Serialize}; +use std::{ + io::{self, Write}, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::io::AsyncBufReadExt; +use tokio_stream::wrappers::LinesStream; +use tokio_util::io::StreamReader; + +pin_project! { + /// A stream of newline delimited JSON. + /// + /// This can be used both as an extractor and as a response. + /// + /// # As extractor + /// + /// ```rust + /// use axum_extra::json_lines::JsonLines; + /// use futures::stream::StreamExt; + /// + /// async fn handler(mut stream: JsonLines<serde_json::Value>) { + /// while let Some(value) = stream.next().await { + /// // ... 
+ /// } + /// } + /// ``` + /// + /// # As response + /// + /// ```rust + /// use axum::{BoxError, response::{IntoResponse, Response}}; + /// use axum_extra::json_lines::JsonLines; + /// use futures::stream::Stream; + /// + /// fn stream_of_values() -> impl Stream<Item = Result<serde_json::Value, BoxError>> { + /// # futures::stream::empty() + /// } + /// + /// async fn handler() -> Response { + /// JsonLines::new(stream_of_values()).into_response() + /// } + /// ``` + // we use `AsExtractor` as the default because you're more likely to name this type if it's used + // as an extractor + pub struct JsonLines<S, T = AsExtractor> { + #[pin] + inner: Inner<S>, + _marker: PhantomData<T>, + } +} + +pin_project! { + #[project = InnerProj] + enum Inner<S> { + Response { + #[pin] + stream: S, + }, + Extractor { + #[pin] + stream: BoxStream<'static, Result<S, axum::Error>>, + }, + } +} + +/// Marker type used to prove that a `JsonLines` was constructed via `FromRequest`. +#[derive(Debug)] +#[non_exhaustive] +pub struct AsExtractor; + +/// Marker type used to prove that a `JsonLines` was constructed via `JsonLines::new`. +#[derive(Debug)] +#[non_exhaustive] +pub struct AsResponse; + +impl<S> JsonLines<S, AsResponse> { + /// Create a new `JsonLines` from a stream of items. 
+ pub fn new(stream: S) -> Self { + Self { + inner: Inner::Response { stream }, + _marker: PhantomData, + } + } +} + +#[async_trait] +impl<B, T> FromRequest<B> for JsonLines<T, AsExtractor> +where + B: HttpBody + Send + 'static, + B::Data: Into<Bytes>, + B::Error: Into<BoxError>, + T: DeserializeOwned, +{ + type Rejection = BodyAlreadyExtracted; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // `Stream::lines` isn't a thing so we have to convert it into an `AsyncBufRead` + // so we can call `AsyncBufReadExt::lines` and then convert it back to a `Stream` + + let body = req.take_body().ok_or_else(BodyAlreadyExtracted::default)?; + let body = BodyStream { body }; + + let stream = body + .map_ok(Into::into) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err)); + let read = StreamReader::new(stream); + let lines_stream = LinesStream::new(read.lines()); + + let deserialized_stream = + lines_stream + .map_err(axum::Error::new) + .and_then(|value| async move { + serde_json::from_str::<T>(&value).map_err(axum::Error::new) + }); + + Ok(Self { + inner: Inner::Extractor { + stream: Box::pin(deserialized_stream), + }, + _marker: PhantomData, + }) + } +} + +// like `axum::extract::BodyStream` except it doesn't box the inner body +// we don't need that since we box the final stream in `Inner::Extractor` +pin_project! 
{ + struct BodyStream<B> { + #[pin] + body: B, + } +} + +impl<B> Stream for BodyStream<B> +where + B: HttpBody + Send + 'static, +{ + type Item = Result<B::Data, B::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.project().body.poll_data(cx) + } +} + +impl<T> Stream for JsonLines<T, AsExtractor> { + type Item = Result<T, axum::Error>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match self.project().inner.project() { + InnerProj::Extractor { stream } => stream.poll_next(cx), + // `JsonLines<_, AsExtractor>` can only be constructed via `FromRequest` + // which doesn't use this variant + InnerProj::Response { .. } => unreachable!(), + } + } +} + +impl<S> IntoResponse for JsonLines<S, AsResponse> +where + S: TryStream + Send + 'static, + S::Ok: Serialize + Send, + S::Error: Into<BoxError>, +{ + fn into_response(self) -> Response { + let inner = match self.inner { + Inner::Response { stream } => stream, + // `JsonLines<_, AsResponse>` can only be constructed via `JsonLines::new` + // which doesn't use this variant + Inner::Extractor { .. 
} => unreachable!(), + }; + + let stream = inner.map_err(Into::into).and_then(|value| async move { + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &value)?; + buf.write_all(b"\n")?; + Ok::<_, BoxError>(buf.into_inner().freeze()) + }); + let stream = StreamBody::new(stream); + + // there is no consensus around mime type yet + // https://github.com/wardi/jsonlines/issues/36 + stream.into_response() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::*; + use axum::{ + routing::{get, post}, + Router, + }; + use futures_util::StreamExt; + use http::StatusCode; + use serde::Deserialize; + use std::{convert::Infallible, error::Error}; + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] + struct User { + id: i32, + } + + #[tokio::test] + async fn extractor() { + let app = Router::new().route( + "/", + post(|mut stream: JsonLines<User>| async move { + assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 1 }); + assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 2 }); + assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 3 }); + + // sources are downcastable to `serde_json::Error` + let err = stream.next().await.unwrap().unwrap_err(); + let _: &serde_json::Error = err + .source() + .unwrap() + .downcast_ref::<serde_json::Error>() + .unwrap(); + }), + ); + + let client = TestClient::new(app); + + let res = client + .post("/") + .body( + vec![ + "{\"id\":1}", + "{\"id\":2}", + "{\"id\":3}", + // to trigger an error for source downcasting + "{\"id\":false}", + ] + .join("\n"), + ) + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn response() { + let app = Router::new().route( + "/", + get(|| async { + let values = futures_util::stream::iter(vec![ + Ok::<_, Infallible>(User { id: 1 }), + Ok::<_, Infallible>(User { id: 2 }), + Ok::<_, Infallible>(User { id: 3 }), + ]); + JsonLines::new(values) + }), + ); + + let client = TestClient::new(app); + 
+ let res = client.get("/").send().await; + + let values = res + .text() + .await + .lines() + .map(|line| serde_json::from_str::<User>(line).unwrap()) + .collect::<Vec<_>>(); + + assert_eq!( + values, + vec![User { id: 1 }, User { id: 2 }, User { id: 3 },] + ); + } +} diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 258ce4f6..a63080eb 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -15,6 +15,7 @@ //! `cookie-signed` | Enables the `SignedCookieJar` extractor | No //! `erased-json` | Enables the `ErasedJson` response | No //! `form` | Enables the `Form` extractor | No +//! `json-lines` | Enables the `JsonLines` extractor and response | No //! `query` | Enables the `Query` extractor | No //! `spa` | Enables the `Spa` router | No //! `typed-routing` | Enables the `TypedPath` routing utilities | No @@ -67,6 +68,9 @@ pub mod extract; pub mod response; pub mod routing; +#[cfg(feature = "json-lines")] +pub mod json_lines; + #[cfg(feature = "typed-routing")] #[doc(hidden)] pub mod __private { diff --git a/axum-extra/src/response/mod.rs b/axum-extra/src/response/mod.rs index 9324c805..7926b8c8 100644 --- a/axum-extra/src/response/mod.rs +++ b/axum-extra/src/response/mod.rs @@ -5,3 +5,7 @@ mod erased_json; #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; + +#[cfg(feature = "json-lines")] +#[doc(no_inline)] +pub use crate::json_lines::JsonLines;