mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-16 14:33:02 +01:00
Make sse::Event
build event as a BytesMut
(#647)
Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
parent
e0a463b463
commit
d5694f0d0d
3 changed files with 197 additions and 129 deletions
|
@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
# Unreleased
|
||||
|
||||
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
|
||||
as field values.
|
||||
- **breaking:** `sse::Event` now panics if a setter method is called twice instead of silently
|
||||
overwriting old values.
|
||||
- **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644])
|
||||
- **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665])
|
||||
|
||||
|
|
|
@ -29,10 +29,11 @@ bytes = "1.0"
|
|||
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
|
||||
http = "0.2.5"
|
||||
http-body = "0.4.4"
|
||||
mime = "0.3.16"
|
||||
hyper = { version = "0.14.14", features = ["server", "tcp", "stream"] }
|
||||
itoa = "1.0.1"
|
||||
matchit = "0.4.4"
|
||||
memchr = "2.4.1"
|
||||
mime = "0.3.16"
|
||||
percent-encoding = "2.1"
|
||||
pin-project-lite = "0.2.7"
|
||||
serde = "1.0"
|
||||
|
|
|
@ -32,15 +32,14 @@ use crate::{
|
|||
response::{IntoResponse, Response},
|
||||
BoxError,
|
||||
};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use futures_util::{
|
||||
ready,
|
||||
stream::{Stream, TryStream},
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
fmt,
|
||||
fmt::Write,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
|
@ -134,9 +133,7 @@ where
|
|||
match this.event_stream.get_pin_mut().poll_next(cx) {
|
||||
Poll::Pending => {
|
||||
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
|
||||
keep_alive
|
||||
.poll_event(cx)
|
||||
.map(|e| Some(Ok(Bytes::from(e.to_string()))))
|
||||
keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
|
@ -145,7 +142,7 @@ where
|
|||
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
|
||||
keep_alive.reset();
|
||||
}
|
||||
Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
|
||||
Poll::Ready(Some(Ok(event.finalize())))
|
||||
}
|
||||
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
|
@ -161,21 +158,10 @@ where
|
|||
}
|
||||
|
||||
/// Server-sent event
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct Event {
|
||||
id: Option<String>,
|
||||
data: Option<DataType>,
|
||||
event: Option<String>,
|
||||
comment: Option<String>,
|
||||
retry: Option<Duration>,
|
||||
}
|
||||
|
||||
// Server-sent event data type
|
||||
#[derive(Debug)]
|
||||
enum DataType {
|
||||
Text(String),
|
||||
#[cfg(feature = "json")]
|
||||
Json(String),
|
||||
buffer: BytesMut,
|
||||
flags: EventFlags,
|
||||
}
|
||||
|
||||
impl Event {
|
||||
|
@ -189,18 +175,22 @@ impl Event {
|
|||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
|
||||
/// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
|
||||
/// - Panics if `data` or `json_data` have already been called.
|
||||
pub fn data<T>(mut self, data: T) -> Event
|
||||
where
|
||||
T: Into<String>,
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let data = data.into();
|
||||
assert_eq!(
|
||||
memchr::memchr(b'\r', data.as_bytes()),
|
||||
None,
|
||||
"SSE data cannot contain carriage returns",
|
||||
);
|
||||
self.data = Some(DataType::Text(data));
|
||||
if self.flags.contains(EventFlags::HAS_DATA) {
|
||||
panic!("Called `EventBuilder::data` multiple times");
|
||||
}
|
||||
|
||||
for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
|
||||
self.field("data", line);
|
||||
}
|
||||
|
||||
self.flags.insert(EventFlags::HAS_DATA);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -209,13 +199,26 @@ impl Event {
|
|||
/// This corresponds to [`MessageEvent`'s data field].
|
||||
///
|
||||
/// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `data` or `json_data` have already been called.
|
||||
#[cfg(feature = "json")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
|
||||
pub fn json_data<T>(mut self, data: T) -> Result<Event, serde_json::Error>
|
||||
pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
{
|
||||
self.data = Some(DataType::Json(serde_json::to_string(&data)?));
|
||||
if self.flags.contains(EventFlags::HAS_DATA) {
|
||||
panic!("Called `EventBuilder::json_data` multiple times");
|
||||
}
|
||||
|
||||
self.buffer.extend_from_slice(b"data:");
|
||||
serde_json::to_writer((&mut self.buffer).writer(), &data)?;
|
||||
self.buffer.put_u8(b'\n');
|
||||
|
||||
self.flags.insert(EventFlags::HAS_DATA);
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
|
@ -223,21 +226,17 @@ impl Event {
|
|||
///
|
||||
/// This field will be ignored by most SSE clients.
|
||||
///
|
||||
/// Unlike other functions, this function can be called multiple times to add many comments.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
|
||||
/// comments.
|
||||
pub fn comment<T>(mut self, comment: T) -> Event
|
||||
where
|
||||
T: Into<String>,
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let comment = comment.into();
|
||||
assert_eq!(
|
||||
memchr::memchr2(b'\r', b'\n', comment.as_bytes()),
|
||||
None,
|
||||
"SSE comment cannot contain newlines or carriage returns"
|
||||
);
|
||||
self.comment = Some(comment);
|
||||
self.field("", comment.as_ref());
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -253,18 +252,19 @@ impl Event {
|
|||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `event` contains any newlines or carriage returns.
|
||||
/// - Panics if `event` contains any newlines or carriage returns.
|
||||
/// - Panics if this function has already been called on this event.
|
||||
pub fn event<T>(mut self, event: T) -> Event
|
||||
where
|
||||
T: Into<String>,
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let event = event.into();
|
||||
assert_eq!(
|
||||
memchr::memchr2(b'\r', b'\n', event.as_bytes()),
|
||||
None,
|
||||
"SSE event name cannot contain newlines or carriage returns"
|
||||
);
|
||||
self.event = Some(event);
|
||||
if self.flags.contains(EventFlags::HAS_EVENT) {
|
||||
panic!("Called `EventBuilder::event` multiple times");
|
||||
}
|
||||
self.flags.insert(EventFlags::HAS_EVENT);
|
||||
|
||||
self.field("event", event.as_ref());
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -273,8 +273,40 @@ impl Event {
|
|||
/// This sets how long clients will wait before reconnecting if they are disconnected from the
|
||||
/// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
|
||||
/// wish, such as if they implement exponential backoff.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if this function has already been called on this event.
|
||||
pub fn retry(mut self, duration: Duration) -> Event {
|
||||
self.retry = Some(duration);
|
||||
if self.flags.contains(EventFlags::HAS_RETRY) {
|
||||
panic!("Called `EventBuilder::retry` multiple times");
|
||||
}
|
||||
self.flags.insert(EventFlags::HAS_RETRY);
|
||||
|
||||
self.buffer.extend_from_slice(b"retry:");
|
||||
|
||||
let secs = duration.as_secs();
|
||||
let millis = duration.subsec_millis();
|
||||
|
||||
if secs > 0 {
|
||||
// format seconds
|
||||
self.buffer
|
||||
.extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
|
||||
|
||||
// pad milliseconds
|
||||
if millis < 10 {
|
||||
self.buffer.extend_from_slice(b"00");
|
||||
} else if millis < 100 {
|
||||
self.buffer.extend_from_slice(b"0");
|
||||
}
|
||||
}
|
||||
|
||||
// format milliseconds
|
||||
self.buffer
|
||||
.extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
|
||||
|
||||
self.buffer.put_u8(b'\n');
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -288,86 +320,58 @@ impl Event {
|
|||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `id` contains any newlines, carriage returns or null characters.
|
||||
/// - Panics if `id` contains any newlines, carriage returns or null characters.
|
||||
/// - Panics if this function has already been called on this event.
|
||||
pub fn id<T>(mut self, id: T) -> Event
|
||||
where
|
||||
T: Into<String>,
|
||||
T: AsRef<str>,
|
||||
{
|
||||
let id = id.into();
|
||||
if self.flags.contains(EventFlags::HAS_ID) {
|
||||
panic!("Called `EventBuilder::id` multiple times");
|
||||
}
|
||||
self.flags.insert(EventFlags::HAS_ID);
|
||||
|
||||
let id = id.as_ref().as_bytes();
|
||||
assert_eq!(
|
||||
memchr::memchr3(b'\r', b'\n', b'\0', id.as_bytes()),
|
||||
memchr::memchr(b'\0', id),
|
||||
None,
|
||||
"Event ID cannot contain newlines, carriage returns or null characters",
|
||||
"Event ID cannot contain null characters",
|
||||
);
|
||||
self.id = Some(id);
|
||||
|
||||
self.field("id", id);
|
||||
self
|
||||
}
|
||||
|
||||
fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
|
||||
let value = value.as_ref();
|
||||
assert_eq!(
|
||||
memchr::memchr2(b'\r', b'\n', value),
|
||||
None,
|
||||
"SSE field value cannot contain newlines or carriage returns",
|
||||
);
|
||||
self.buffer.extend_from_slice(name.as_bytes());
|
||||
self.buffer.put_u8(b':');
|
||||
// Prevent values that start with spaces having that space stripped
|
||||
if value.starts_with(b" ") {
|
||||
self.buffer.put_u8(b' ');
|
||||
}
|
||||
self.buffer.extend_from_slice(value);
|
||||
self.buffer.put_u8(b'\n');
|
||||
}
|
||||
|
||||
fn finalize(mut self) -> Bytes {
|
||||
self.buffer.put_u8(b'\n');
|
||||
self.buffer.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Event {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if let Some(comment) = &self.comment {
|
||||
":".fmt(f)?;
|
||||
comment.fmt(f)?;
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
|
||||
if let Some(event) = &self.event {
|
||||
"event: ".fmt(f)?;
|
||||
event.fmt(f)?;
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
|
||||
match &self.data {
|
||||
Some(DataType::Text(data)) => {
|
||||
for line in data.split('\n') {
|
||||
"data: ".fmt(f)?;
|
||||
line.fmt(f)?;
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "json")]
|
||||
Some(DataType::Json(data)) => {
|
||||
"data:".fmt(f)?;
|
||||
data.fmt(f)?;
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let Some(id) = &self.id {
|
||||
"id: ".fmt(f)?;
|
||||
id.fmt(f)?;
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
|
||||
if let Some(duration) = &self.retry {
|
||||
"retry:".fmt(f)?;
|
||||
|
||||
let secs = duration.as_secs();
|
||||
let millis = duration.subsec_millis();
|
||||
|
||||
if secs > 0 {
|
||||
// format seconds
|
||||
secs.fmt(f)?;
|
||||
|
||||
// pad milliseconds
|
||||
if millis < 10 {
|
||||
f.write_str("00")?;
|
||||
} else if millis < 100 {
|
||||
f.write_char('0')?;
|
||||
}
|
||||
}
|
||||
|
||||
// format milliseconds
|
||||
millis.fmt(f)?;
|
||||
|
||||
f.write_char('\n')?;
|
||||
}
|
||||
|
||||
f.write_char('\n')?;
|
||||
|
||||
Ok(())
|
||||
bitflags::bitflags! {
|
||||
#[derive(Default)]
|
||||
struct EventFlags: u8 {
|
||||
const HAS_DATA = 0b0001;
|
||||
const HAS_EVENT = 0b0010;
|
||||
const HAS_RETRY = 0b0100;
|
||||
const HAS_ID = 0b1000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -375,7 +379,7 @@ impl fmt::Display for Event {
|
|||
/// of each message, and the associated stream.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeepAlive {
|
||||
comment_text: Cow<'static, str>,
|
||||
event: Bytes,
|
||||
max_interval: Duration,
|
||||
}
|
||||
|
||||
|
@ -383,7 +387,7 @@ impl KeepAlive {
|
|||
/// Create a new `KeepAlive`.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
comment_text: Cow::Borrowed(""),
|
||||
event: Bytes::from_static(b":\n\n"),
|
||||
max_interval: Duration::from_secs(15),
|
||||
}
|
||||
}
|
||||
|
@ -399,11 +403,17 @@ impl KeepAlive {
|
|||
/// Customize the text of the keep-alive message.
|
||||
///
|
||||
/// Default is an empty comment.
|
||||
///
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
|
||||
/// comments.
|
||||
pub fn text<I>(mut self, text: I) -> Self
|
||||
where
|
||||
I: Into<Cow<'static, str>>,
|
||||
I: AsRef<str>,
|
||||
{
|
||||
self.comment_text = text.into();
|
||||
self.event = Event::default().comment(text).finalize();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -437,13 +447,12 @@ impl KeepAliveStream {
|
|||
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
|
||||
}
|
||||
|
||||
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
|
||||
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
|
||||
let this = self.as_mut().project();
|
||||
|
||||
ready!(this.alive_timer.poll(cx));
|
||||
|
||||
let comment_str = this.keep_alive.comment_text.clone();
|
||||
let event = Event::default().comment(comment_str);
|
||||
let event = this.keep_alive.event.clone();
|
||||
|
||||
self.reset();
|
||||
|
||||
|
@ -451,6 +460,32 @@ impl KeepAliveStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
|
||||
MemchrSplit {
|
||||
needle,
|
||||
haystack: Some(haystack),
|
||||
}
|
||||
}
|
||||
|
||||
struct MemchrSplit<'a> {
|
||||
needle: u8,
|
||||
haystack: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for MemchrSplit<'a> {
|
||||
type Item = &'a [u8];
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let haystack = self.haystack?;
|
||||
if let Some(pos) = memchr::memchr(self.needle, haystack) {
|
||||
let (front, back) = haystack.split_at(pos);
|
||||
self.haystack = Some(&back[1..]);
|
||||
Some(front)
|
||||
} else {
|
||||
self.haystack.take()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -462,10 +497,10 @@ mod tests {
|
|||
#[test]
|
||||
fn leading_space_is_not_stripped() {
|
||||
let no_leading_space = Event::default().data("\tfoobar");
|
||||
assert_eq!(no_leading_space.to_string(), "data: \tfoobar\n\n");
|
||||
assert_eq!(&*no_leading_space.finalize(), b"data:\tfoobar\n\n");
|
||||
|
||||
let leading_space = Event::default().data(" foobar");
|
||||
assert_eq!(leading_space.to_string(), "data: foobar\n\n");
|
||||
assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -609,4 +644,32 @@ mod tests {
|
|||
|
||||
fields
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn memchr_spliting() {
|
||||
assert_eq!(
|
||||
memchr_split(2, &[]).collect::<Vec<_>>(),
|
||||
[&[]] as [&[u8]; 1]
|
||||
);
|
||||
assert_eq!(
|
||||
memchr_split(2, &[2]).collect::<Vec<_>>(),
|
||||
[&[], &[]] as [&[u8]; 2]
|
||||
);
|
||||
assert_eq!(
|
||||
memchr_split(2, &[1]).collect::<Vec<_>>(),
|
||||
[&[1]] as [&[u8]; 1]
|
||||
);
|
||||
assert_eq!(
|
||||
memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
|
||||
[&[1], &[]] as [&[u8]; 2]
|
||||
);
|
||||
assert_eq!(
|
||||
memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
|
||||
[&[], &[1]] as [&[u8]; 2]
|
||||
);
|
||||
assert_eq!(
|
||||
memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
|
||||
[&[1], &[], &[1]] as [&[u8]; 3]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue