From 6d24a8695f8cf063df8a7f1ed2d8b9ab579aeb49 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 3 Jan 2022 18:48:50 +0100 Subject: [PATCH] Add SSE tests (#652) * Add SSE tests * Simplify keep alive test a bit * More robust keep-alive tests * rename a bit --- axum/Cargo.toml | 2 +- axum/src/response/sse.rs | 163 +++++++++++++++++++++++++++++++++++++-- axum/src/test_helpers.rs | 14 ++++ examples/sse/src/main.rs | 8 +- 4 files changed, 178 insertions(+), 9 deletions(-) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 332cb815..14ea03cc 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -56,7 +56,7 @@ futures = "0.3" reqwest = { version = "0.11", default-features = false, features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] } +tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" tracing = "0.1" uuid = { version = "0.8", features = ["serde", "v4"] } diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index fbf92d88..c5b6f737 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -451,11 +451,162 @@ impl KeepAliveStream { } } -#[test] -fn leading_space_is_not_stripped() { - let no_leading_space = Event::default().data("\tfoobar"); - assert_eq!(no_leading_space.to_string(), "data: \tfoobar\n\n"); +#[cfg(test)] +mod tests { + use super::*; + use crate::{routing::get, test_helpers::*, Router}; + use futures::stream; + use std::{collections::HashMap, convert::Infallible}; + use tokio_stream::StreamExt as _; - let leading_space = Event::default().data(" foobar"); - assert_eq!(leading_space.to_string(), "data: foobar\n\n"); + #[test] + fn leading_space_is_not_stripped() { + let no_leading_space = Event::default().data("\tfoobar"); + assert_eq!(no_leading_space.to_string(), "data: \tfoobar\n\n"); + + let leading_space = Event::default().data(" foobar"); + assert_eq!(leading_space.to_string(), "data: foobar\n\n"); + } + + #[tokio::test] + async fn basic() { + let app = Router::new().route( + "/", + get(|| async { + let stream = stream::iter(vec![ + Event::default().data("one").comment("this is a comment"), + Event::default() + .json_data(serde_json::json!({ "foo": "bar" })) + .unwrap(), + Event::default() + .event("three") + .retry(Duration::from_secs(30)) + .id("unique-id"), + ]) + .map(Ok::<_, Infallible>); + Sse::new(stream) + }), + ); + + let client = TestClient::new(app); + let mut stream = client.get("/").send().await; + + assert_eq!(stream.headers()["content-type"], "text/event-stream"); + assert_eq!(stream.headers()["cache-control"], "no-cache"); + + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("data").unwrap(), "one"); + assert_eq!(event_fields.get("comment").unwrap(), "this is a comment"); + + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}"); + assert!(event_fields.get("comment").is_none()); + + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("event").unwrap(), "three"); + assert_eq!(event_fields.get("retry").unwrap(), "30000"); + assert_eq!(event_fields.get("id").unwrap(), "unique-id"); + assert!(event_fields.get("comment").is_none()); + + assert!(stream.chunk_text().await.is_none()); + } + + #[tokio::test(start_paused = true)] + async fn keep_alive() { + const DELAY: Duration = Duration::from_secs(5); + + let app = Router::new().route( + "/", + get(|| async { + let stream = stream::repeat_with(|| Event::default().data("msg")) + .map(Ok::<_, Infallible>) + .throttle(DELAY); + + Sse::new(stream).keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(1)) + .text("keep-alive-text"), + ) + }), + ); + + let client = TestClient::new(app); + let mut stream = client.get("/").send().await; + + for _ in 0..5 { + // first message should be an event + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("data").unwrap(), "msg"); + + // then 4 seconds of keep-alive messages + for _ in 0..4 { + tokio::time::sleep(Duration::from_secs(1)).await; + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text"); + } + } + } + + #[tokio::test(start_paused = true)] + async fn keep_alive_ends_when_the_stream_ends() { + const DELAY: Duration = Duration::from_secs(5); + + let app = Router::new().route( + "/", + get(|| async { + let stream = stream::repeat_with(|| Event::default().data("msg")) + .map(Ok::<_, Infallible>) + .throttle(DELAY) + .take(2); + + Sse::new(stream).keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(1)) + .text("keep-alive-text"), + ) + }), + ); + + let client = TestClient::new(app); + let mut stream = client.get("/").send().await; + + // first message should be an event + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("data").unwrap(), "msg"); + + // then 4 seconds of keep-alive messages + for _ in 0..4 { + tokio::time::sleep(Duration::from_secs(1)).await; + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text"); + } + + // then the last event + let event_fields = parse_event(&stream.chunk_text().await.unwrap()); + assert_eq!(event_fields.get("data").unwrap(), "msg"); + + // then no more events or keep-alive messages + assert!(stream.chunk_text().await.is_none()); + } + + fn parse_event(payload: &str) -> HashMap { + let mut fields = HashMap::new(); + + let mut lines = payload.lines().peekable(); + while let Some(line) = lines.next() { + if line.is_empty() { + assert!(lines.next().is_none()); + break; + } + + let (mut key, value) = line.split_once(':').unwrap(); + let value = value.trim(); + if key.is_empty() { + key = "comment"; + } + fields.insert(key.to_owned(), value.to_owned()); + } + + fields + } } diff --git a/axum/src/test_helpers.rs b/axum/src/test_helpers.rs index fef2f08b..da9b4647 100644 --- a/axum/src/test_helpers.rs +++ b/axum/src/test_helpers.rs @@ -2,6 +2,7 @@ use crate::body::HttpBody; use crate::BoxError; +use bytes::Bytes; use http::{ header::{HeaderName, HeaderValue}, Request, StatusCode, @@ -132,4 +133,17 @@ impl TestResponse { pub(crate) fn status(&self) -> StatusCode { self.response.status() } + + pub(crate) fn headers(&self) -> &http::HeaderMap { + self.response.headers() + } + + pub(crate) async fn chunk(&mut self) -> Option { + self.response.chunk().await.unwrap() + } + + pub(crate) async fn chunk_text(&mut self) -> Option { + let chunk = self.chunk().await?; + Some(String::from_utf8(chunk.to_vec()).unwrap()) + } } diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 8aaea644..772e9d01 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -56,7 +56,11 @@ async fn sse_handler( // A `Stream` that repeats an event every second let stream = stream::repeat_with(|| Event::default().data("hi!")) .map(Ok) - .throttle(Duration::from_secs(1)); + .throttle(Duration::from_secs(10)); - Sse::new(stream) + Sse::new(stream).keep_alive( + axum::response::sse::KeepAlive::new() + .interval(Duration::from_secs(1)) + .text("keep-alive-text"), + ) }