Make sse::Event build event as a BytesMut (#647)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Sabrina Jewson 2022-01-04 00:21:34 +00:00 committed by David Pedersen
parent e0a463b463
commit d5694f0d0d
3 changed files with 197 additions and 129 deletions

View file

@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
as field values.
- **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently
overwriting old values.
- **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644])
- **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665])

View file

@ -29,10 +29,11 @@ bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2.5"
http-body = "0.4.4"
mime = "0.3.16"
hyper = { version = "0.14.14", features = ["server", "tcp", "stream"] }
itoa = "1.0.1"
matchit = "0.4.4"
memchr = "2.4.1"
mime = "0.3.16"
percent-encoding = "2.1"
pin-project-lite = "0.2.7"
serde = "1.0"

View file

@ -32,15 +32,14 @@ use crate::{
response::{IntoResponse, Response},
BoxError,
};
use bytes::{BufMut, BytesMut};
use futures_util::{
ready,
stream::{Stream, TryStream},
};
use pin_project_lite::pin_project;
use std::{
borrow::Cow,
fmt,
fmt::Write,
future::Future,
pin::Pin,
task::{Context, Poll},
@ -134,9 +133,7 @@ where
match this.event_stream.get_pin_mut().poll_next(cx) {
Poll::Pending => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive
.poll_event(cx)
.map(|e| Some(Ok(Bytes::from(e.to_string()))))
keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
} else {
Poll::Pending
}
@ -145,7 +142,7 @@ where
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.reset();
}
Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
Poll::Ready(Some(Ok(event.finalize())))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
@ -161,21 +158,10 @@ where
}
/// Server-sent event
#[derive(Default, Debug)]
#[derive(Debug, Default, Clone)]
pub struct Event {
id: Option<String>,
data: Option<DataType>,
event: Option<String>,
comment: Option<String>,
retry: Option<Duration>,
}
// Server-sent event data type
#[derive(Debug)]
enum DataType {
Text(String),
#[cfg(feature = "json")]
Json(String),
buffer: BytesMut,
flags: EventFlags,
}
impl Event {
@ -189,18 +175,22 @@ impl Event {
///
/// # Panics
///
/// Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
/// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
/// - Panics if `data` or `json_data` have already been called.
pub fn data<T>(mut self, data: T) -> Event
where
T: Into<String>,
T: AsRef<str>,
{
let data = data.into();
assert_eq!(
memchr::memchr(b'\r', data.as_bytes()),
None,
"SSE data cannot contain carriage returns",
);
self.data = Some(DataType::Text(data));
if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::data` multiple times");
}
for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
self.field("data", line);
}
self.flags.insert(EventFlags::HAS_DATA);
self
}
@ -209,13 +199,26 @@ impl Event {
/// This corresponds to [`MessageEvent`'s data field].
///
/// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
///
/// # Panics
///
/// Panics if `data` or `json_data` have already been called.
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub fn json_data<T>(mut self, data: T) -> Result<Event, serde_json::Error>
pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
where
T: serde::Serialize,
{
self.data = Some(DataType::Json(serde_json::to_string(&data)?));
if self.flags.contains(EventFlags::HAS_DATA) {
panic!("Called `EventBuilder::json_data` multiple times");
}
self.buffer.extend_from_slice(b"data:");
serde_json::to_writer((&mut self.buffer).writer(), &data)?;
self.buffer.put_u8(b'\n');
self.flags.insert(EventFlags::HAS_DATA);
Ok(self)
}
@ -223,21 +226,17 @@ impl Event {
///
/// This field will be ignored by most SSE clients.
///
/// Unlike other functions, this function can be called multiple times to add many comments.
///
/// # Panics
///
/// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
/// comments.
pub fn comment<T>(mut self, comment: T) -> Event
where
T: Into<String>,
T: AsRef<str>,
{
let comment = comment.into();
assert_eq!(
memchr::memchr2(b'\r', b'\n', comment.as_bytes()),
None,
"SSE comment cannot contain newlines or carriage returns"
);
self.comment = Some(comment);
self.field("", comment.as_ref());
self
}
@ -253,18 +252,19 @@ impl Event {
///
/// # Panics
///
/// Panics if `event` contains any newlines or carriage returns.
/// - Panics if `event` contains any newlines or carriage returns.
/// - Panics if this function has already been called on this event.
pub fn event<T>(mut self, event: T) -> Event
where
T: Into<String>,
T: AsRef<str>,
{
let event = event.into();
assert_eq!(
memchr::memchr2(b'\r', b'\n', event.as_bytes()),
None,
"SSE event name cannot contain newlines or carriage returns"
);
self.event = Some(event);
if self.flags.contains(EventFlags::HAS_EVENT) {
panic!("Called `EventBuilder::event` multiple times");
}
self.flags.insert(EventFlags::HAS_EVENT);
self.field("event", event.as_ref());
self
}
@ -273,8 +273,40 @@ impl Event {
/// This sets how long clients will wait before reconnecting if they are disconnected from the
/// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
/// wish, such as if they implement exponential backoff.
///
/// # Panics
///
/// Panics if this function has already been called on this event.
pub fn retry(mut self, duration: Duration) -> Event {
self.retry = Some(duration);
if self.flags.contains(EventFlags::HAS_RETRY) {
panic!("Called `EventBuilder::retry` multiple times");
}
self.flags.insert(EventFlags::HAS_RETRY);
self.buffer.extend_from_slice(b"retry:");
let secs = duration.as_secs();
let millis = duration.subsec_millis();
if secs > 0 {
// format seconds
self.buffer
.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
// pad milliseconds
if millis < 10 {
self.buffer.extend_from_slice(b"00");
} else if millis < 100 {
self.buffer.extend_from_slice(b"0");
}
}
// format milliseconds
self.buffer
.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
self.buffer.put_u8(b'\n');
self
}
@ -288,86 +320,58 @@ impl Event {
///
/// # Panics
///
/// Panics if `id` contains any newlines, carriage returns or null characters.
/// - Panics if `id` contains any newlines, carriage returns or null characters.
/// - Panics if this function has already been called on this event.
pub fn id<T>(mut self, id: T) -> Event
where
T: Into<String>,
T: AsRef<str>,
{
let id = id.into();
if self.flags.contains(EventFlags::HAS_ID) {
panic!("Called `EventBuilder::id` multiple times");
}
self.flags.insert(EventFlags::HAS_ID);
let id = id.as_ref().as_bytes();
assert_eq!(
memchr::memchr3(b'\r', b'\n', b'\0', id.as_bytes()),
memchr::memchr(b'\0', id),
None,
"Event ID cannot contain newlines, carriage returns or null characters",
"Event ID cannot contain null characters",
);
self.id = Some(id);
self.field("id", id);
self
}
fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
let value = value.as_ref();
assert_eq!(
memchr::memchr2(b'\r', b'\n', value),
None,
"SSE field value cannot contain newlines or carriage returns",
);
self.buffer.extend_from_slice(name.as_bytes());
self.buffer.put_u8(b':');
// Prevent values that start with spaces having that space stripped
if value.starts_with(b" ") {
self.buffer.put_u8(b' ');
}
self.buffer.extend_from_slice(value);
self.buffer.put_u8(b'\n');
}
fn finalize(mut self) -> Bytes {
self.buffer.put_u8(b'\n');
self.buffer.freeze()
}
}
impl fmt::Display for Event {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(comment) = &self.comment {
":".fmt(f)?;
comment.fmt(f)?;
f.write_char('\n')?;
}
if let Some(event) = &self.event {
"event: ".fmt(f)?;
event.fmt(f)?;
f.write_char('\n')?;
}
match &self.data {
Some(DataType::Text(data)) => {
for line in data.split('\n') {
"data: ".fmt(f)?;
line.fmt(f)?;
f.write_char('\n')?;
}
}
#[cfg(feature = "json")]
Some(DataType::Json(data)) => {
"data:".fmt(f)?;
data.fmt(f)?;
f.write_char('\n')?;
}
None => {}
}
if let Some(id) = &self.id {
"id: ".fmt(f)?;
id.fmt(f)?;
f.write_char('\n')?;
}
if let Some(duration) = &self.retry {
"retry:".fmt(f)?;
let secs = duration.as_secs();
let millis = duration.subsec_millis();
if secs > 0 {
// format seconds
secs.fmt(f)?;
// pad milliseconds
if millis < 10 {
f.write_str("00")?;
} else if millis < 100 {
f.write_char('0')?;
}
}
// format milliseconds
millis.fmt(f)?;
f.write_char('\n')?;
}
f.write_char('\n')?;
Ok(())
bitflags::bitflags! {
#[derive(Default)]
struct EventFlags: u8 {
const HAS_DATA = 0b0001;
const HAS_EVENT = 0b0010;
const HAS_RETRY = 0b0100;
const HAS_ID = 0b1000;
}
}
@ -375,7 +379,7 @@ impl fmt::Display for Event {
/// of each message, and the associated stream.
#[derive(Debug, Clone)]
pub struct KeepAlive {
comment_text: Cow<'static, str>,
event: Bytes,
max_interval: Duration,
}
@ -383,7 +387,7 @@ impl KeepAlive {
/// Create a new `KeepAlive`.
pub fn new() -> Self {
Self {
comment_text: Cow::Borrowed(""),
event: Bytes::from_static(b":\n\n"),
max_interval: Duration::from_secs(15),
}
}
@ -399,11 +403,17 @@ impl KeepAlive {
/// Customize the text of the keep-alive message.
///
/// Default is an empty comment.
///
///
/// # Panics
///
/// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
/// comments.
pub fn text<I>(mut self, text: I) -> Self
where
I: Into<Cow<'static, str>>,
I: AsRef<str>,
{
self.comment_text = text.into();
self.event = Event::default().comment(text).finalize();
self
}
}
@ -437,13 +447,12 @@ impl KeepAliveStream {
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
}
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
let this = self.as_mut().project();
ready!(this.alive_timer.poll(cx));
let comment_str = this.keep_alive.comment_text.clone();
let event = Event::default().comment(comment_str);
let event = this.keep_alive.event.clone();
self.reset();
@ -451,6 +460,32 @@ impl KeepAliveStream {
}
}
fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
MemchrSplit {
needle,
haystack: Some(haystack),
}
}
struct MemchrSplit<'a> {
needle: u8,
haystack: Option<&'a [u8]>,
}
impl<'a> Iterator for MemchrSplit<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
let haystack = self.haystack?;
if let Some(pos) = memchr::memchr(self.needle, haystack) {
let (front, back) = haystack.split_at(pos);
self.haystack = Some(&back[1..]);
Some(front)
} else {
self.haystack.take()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -462,10 +497,10 @@ mod tests {
#[test]
fn leading_space_is_not_stripped() {
let no_leading_space = Event::default().data("\tfoobar");
assert_eq!(no_leading_space.to_string(), "data: \tfoobar\n\n");
assert_eq!(&*no_leading_space.finalize(), b"data:\tfoobar\n\n");
let leading_space = Event::default().data(" foobar");
assert_eq!(leading_space.to_string(), "data: foobar\n\n");
assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n");
}
#[tokio::test]
@ -609,4 +644,32 @@ mod tests {
fields
}
#[test]
fn memchr_spliting() {
assert_eq!(
memchr_split(2, &[]).collect::<Vec<_>>(),
[&[]] as [&[u8]; 1]
);
assert_eq!(
memchr_split(2, &[2]).collect::<Vec<_>>(),
[&[], &[]] as [&[u8]; 2]
);
assert_eq!(
memchr_split(2, &[1]).collect::<Vec<_>>(),
[&[1]] as [&[u8]; 1]
);
assert_eq!(
memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
[&[1], &[]] as [&[u8]; 2]
);
assert_eq!(
memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
[&[], &[1]] as [&[u8]; 2]
);
assert_eq!(
memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
[&[1], &[], &[1]] as [&[u8]; 3]
);
}
}