fix(sse): skip sse incompatible chars of serde_json::RawValue (#2992)

This commit is contained in:
Jan 2024-10-17 16:49:39 +02:00 committed by GitHub
parent c5a451d94b
commit 04fee944ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 2 deletions

View file

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992])
- **breaking:** Move `Host` extractor to `axum-extra` ([#2956])
- **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903])
- **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978])
@ -26,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#2961]: https://github.com/tokio-rs/axum/pull/2961
[#2974]: https://github.com/tokio-rs/axum/pull/2974
[#2978]: https://github.com/tokio-rs/axum/pull/2978
[#2992]: https://github.com/tokio-rs/axum/pull/2992
# 0.8.0

View file

@ -117,7 +117,7 @@ quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_json = { version = "1.0", features = ["raw_value"] }
time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1"

View file

@ -208,12 +208,29 @@ impl Event {
where
T: serde::Serialize,
{
struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
impl std::io::Write for IgnoreNewLines<'_> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut last_split = 0;
for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
self.0.write_all(&buf[last_split..delimiter])?;
last_split = delimiter + 1;
}
self.0.write_all(&buf[last_split..])?;
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
self.0.flush()
}
}
if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::json_data` multiple times");
}
self.buffer.extend_from_slice(b"data: ");
serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?;
serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
.map_err(axum_core::Error::new)?;
self.buffer.put_u8(b'\n');
self.flags.insert(EventFlags::HAS_DATA);
@ -515,6 +532,7 @@ mod tests {
use super::*;
use crate::{routing::get, test_helpers::*, Router};
use futures_util::stream;
use serde_json::value::RawValue;
use std::{collections::HashMap, convert::Infallible};
use tokio_stream::StreamExt as _;
@ -527,6 +545,18 @@ mod tests {
assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n");
}
#[test]
fn valid_json_raw_value_chars_stripped() {
let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
let json_raw_value_event = Event::default()
.json_data(serde_json::from_str::<&RawValue>(json_string).unwrap())
.unwrap();
assert_eq!(
&*json_raw_value_event.finalize(),
format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
);
}
#[crate::test]
async fn basic() {
let app = Router::new().route(