Implement SSE using responses (#98)

This commit is contained in:
Kai Jewson 2021-08-14 16:29:09 +01:00 committed by GitHub
parent 045287aef9
commit 9cd543401f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 404 additions and 540 deletions

View file

@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `handle_error` to `service::OnMethod` ([#160](https://github.com/tokio-rs/axum/pull/160))
- Add `NestedUri` for extracting request URI in nested services ([#161](https://github.com/tokio-rs/axum/pull/161))
- Implement `FromRequest` for `http::Extensions`
- Implement SSE as an `IntoResponse` instead of a service ([#98](https://github.com/tokio-rs/axum/pull/98))
## Breaking changes
@ -62,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `service::OnMethod`
- `handler::OnMethod`
- `routing::Nested`
- Remove `axum::sse` ([#98](https://github.com/tokio-rs/axum/pull/98))
# 0.1.3 (06. August, 2021)

View file

@ -33,6 +33,7 @@ tokio = { version = "1", features = ["time"] }
tokio-util = "0.6"
tower = { version = "0.4", features = ["util", "buffer", "make"] }
tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] }
sync_wrapper = "0.1.1"
# optional dependencies
tokio-tungstenite = { optional = true, version = "0.14" }

View file

@ -4,7 +4,12 @@
//! cargo run --example sse --features=headers
//! ```
use axum::{extract::TypedHeader, prelude::*, routing::nest, sse::Event};
use axum::{
extract::TypedHeader,
prelude::*,
response::sse::{sse, Event, Sse},
routing::nest,
};
use futures::stream::{self, Stream};
use http::StatusCode;
use std::{convert::Infallible, net::SocketAddr, time::Duration};
@ -30,7 +35,7 @@ async fn main() {
// build our application with a route
let app = nest("/", static_files_service)
.route("/sse", axum::sse::sse(make_stream))
.route("/sse", get(sse_handler))
.layer(TraceLayer::new_for_http());
// run it
@ -42,10 +47,9 @@ async fn main() {
.unwrap();
}
async fn make_stream(
// sse handlers can also use extractors
async fn sse_handler(
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
) -> Result<impl Stream<Item = Result<Event, Infallible>>, Infallible> {
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
println!("`{}` connected", user_agent.as_str());
// A `Stream` that repeats an event every second
@ -53,5 +57,5 @@ async fn make_stream(
.map(Ok)
.throttle(Duration::from_secs(1));
Ok(stream)
sse(stream)
}

View file

@ -725,7 +725,6 @@ pub mod handler;
pub mod response;
pub mod routing;
pub mod service;
pub mod sse;
#[cfg(test)]
mod tests;

View file

@ -16,6 +16,10 @@ use tower::{util::Either, BoxError};
#[doc(no_inline)]
pub use crate::Json;
pub mod sse;
pub use sse::{sse, Sse};
/// Trait for generating responses.
///
/// Types that implement `IntoResponse` can be returned from handlers.

387
src/response/sse.rs Normal file
View file

@ -0,0 +1,387 @@
//! Server-Sent Events (SSE) responses.
//!
//! # Example
//!
//! ```
//! use axum::prelude::*;
//! use axum::response::sse::{sse, Event, KeepAlive, Sse};
//! use std::{time::Duration, convert::Infallible};
//! use tokio_stream::StreamExt as _ ;
//! use futures::stream::{self, Stream};
//!
//! let app = route("/sse", get(sse_handler));
//!
//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
//! // A `Stream` that repeats an event every second
//! let stream = stream::repeat_with(|| Event::default().data("hi!"))
//! .map(Ok)
//! .throttle(Duration::from_secs(1));
//!
//! sse(stream).keep_alive(KeepAlive::default())
//! }
//! # async {
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
use crate::response::IntoResponse;
use bytes::Bytes;
use futures_util::{
ready,
stream::{Stream, TryStream},
};
use http::Response;
use http_body::Body as HttpBody;
use pin_project_lite::pin_project;
use serde::Serialize;
use std::{
borrow::Cow,
fmt,
fmt::Write,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use sync_wrapper::SyncWrapper;
use tokio::time::Sleep;
use tower::BoxError;
/// Create a new [`Sse`] response that will respond with the given stream of
/// [`Event`]s.
///
/// See the [module docs](self) for more details.
pub fn sse<S>(stream: S) -> Sse<S>
where
S: TryStream<Ok = Event> + Send + 'static,
S::Error: Into<BoxError>,
{
Sse {
stream,
keep_alive: None,
}
}
/// An SSE response, created by [`sse`].
#[derive(Clone)]
pub struct Sse<S> {
stream: S,
keep_alive: Option<KeepAlive>,
}
impl<S> Sse<S> {
/// Configure the interval between keep-alive messages.
///
/// Defaults to no keep-alive messages.
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
}
}
impl<S> fmt::Debug for Sse<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sse")
.field("stream", &format_args!("{}", std::any::type_name::<S>()))
.field("keep_alive", &self.keep_alive)
.finish()
}
}
impl<S, E> IntoResponse for Sse<S>
where
S: Stream<Item = Result<Event, E>> + Send + 'static,
E: Into<BoxError>,
{
type Body = Body<S>;
type BodyError = E;
fn into_response(self) -> Response<Self::Body> {
let body = Body {
event_stream: SyncWrapper::new(self.stream),
keep_alive: self.keep_alive.map(KeepAliveStream::new),
};
Response::builder()
.header(http::header::CONTENT_TYPE, "text/event-stream")
.header(http::header::CACHE_CONTROL, "no-cache")
.body(body)
.unwrap()
}
}
pin_project! {
/// The body of an SSE response.
#[derive(Debug)]
pub struct Body<S> {
#[pin]
event_stream: SyncWrapper<S>,
#[pin]
keep_alive: Option<KeepAliveStream>,
}
}
impl<S, E> HttpBody for Body<S>
where
S: Stream<Item = Result<Event, E>>,
{
type Data = Bytes;
type Error = E;
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
let this = self.project();
match this.event_stream.get_pin_mut().poll_next(cx) {
Poll::Pending => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive
.poll_event(cx)
.map(|e| Some(Ok(Bytes::from(e.to_string()))))
} else {
Poll::Pending
}
}
Poll::Ready(Some(Ok(event))) => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.reset();
}
Poll::Ready(Some(Ok(Bytes::from(event.to_string()))))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
}
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}
/// Server-sent event
#[derive(Default, Debug)]
pub struct Event {
name: Option<String>,
id: Option<String>,
data: Option<DataType>,
event: Option<String>,
comment: Option<String>,
retry: Option<Duration>,
}
// Server-sent event data type
#[derive(Debug)]
enum DataType {
Text(String),
Json(String),
}
impl Event {
/// Set Server-sent event data
/// data field(s) ("data:<content>")
pub fn data<T>(mut self, data: T) -> Event
where
T: Into<String>,
{
self.data = Some(DataType::Text(data.into()));
self
}
/// Set Server-sent event data
/// data field(s) ("data:<content>")
pub fn json_data<T>(mut self, data: T) -> Result<Event, serde_json::Error>
where
T: Serialize,
{
self.data = Some(DataType::Json(serde_json::to_string(&data)?));
Ok(self)
}
/// Set Server-sent event comment
/// Comment field (":<comment-text>")
pub fn comment<T>(mut self, comment: T) -> Event
where
T: Into<String>,
{
self.comment = Some(comment.into());
self
}
/// Set Server-sent event event
/// Event name field ("event:<event-name>")
pub fn event<T>(mut self, event: T) -> Event
where
T: Into<String>,
{
self.event = Some(event.into());
self
}
/// Set Server-sent event retry
/// Retry timeout field ("retry:<timeout>")
pub fn retry(mut self, duration: Duration) -> Event {
self.retry = Some(duration);
self
}
/// Set Server-sent event id
/// Identifier field ("id:<identifier>")
pub fn id<T>(mut self, id: T) -> Event
where
T: Into<String>,
{
self.id = Some(id.into());
self
}
}
impl fmt::Display for Event {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(comment) = &self.comment {
":".fmt(f)?;
comment.fmt(f)?;
f.write_char('\n')?;
}
if let Some(event) = &self.event {
"event:".fmt(f)?;
event.fmt(f)?;
f.write_char('\n')?;
}
match &self.data {
Some(DataType::Text(data)) => {
for line in data.split('\n') {
"data:".fmt(f)?;
line.fmt(f)?;
f.write_char('\n')?;
}
}
Some(DataType::Json(data)) => {
"data:".fmt(f)?;
data.fmt(f)?;
f.write_char('\n')?;
}
None => {}
}
if let Some(id) = &self.id {
"id:".fmt(f)?;
id.fmt(f)?;
f.write_char('\n')?;
}
if let Some(duration) = &self.retry {
"retry:".fmt(f)?;
let secs = duration.as_secs();
let millis = duration.subsec_millis();
if secs > 0 {
// format seconds
secs.fmt(f)?;
// pad milliseconds
if millis < 10 {
f.write_str("00")?;
} else if millis < 100 {
f.write_char('0')?;
}
}
// format milliseconds
millis.fmt(f)?;
f.write_char('\n')?;
}
f.write_char('\n')?;
Ok(())
}
}
/// Configure the interval between keep-alive messages, the content
/// of each message, and the associated stream.
#[derive(Debug, Clone)]
pub struct KeepAlive {
comment_text: Cow<'static, str>,
max_interval: Duration,
}
impl KeepAlive {
/// Create a new `KeepAlive`.
pub fn new() -> Self {
Self {
comment_text: Cow::Borrowed(""),
max_interval: Duration::from_secs(15),
}
}
/// Customize the interval between keep-alive messages.
///
/// Default is 15 seconds.
pub fn interval(mut self, time: Duration) -> Self {
self.max_interval = time;
self
}
/// Customize the text of the keep-alive message.
///
/// Default is an empty comment.
pub fn text<I>(mut self, text: I) -> Self
where
I: Into<Cow<'static, str>>,
{
self.comment_text = text.into();
self
}
}
impl Default for KeepAlive {
fn default() -> Self {
Self::new()
}
}
pin_project! {
#[derive(Debug)]
struct KeepAliveStream {
keep_alive: KeepAlive,
#[pin]
alive_timer: Sleep,
}
}
impl KeepAliveStream {
fn new(keep_alive: KeepAlive) -> Self {
Self {
alive_timer: tokio::time::sleep(keep_alive.max_interval),
keep_alive,
}
}
fn reset(self: Pin<&mut Self>) {
let this = self.project();
this.alive_timer
.reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
}
fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Event> {
let this = self.as_mut().project();
ready!(this.alive_timer.poll(cx));
let comment_str = this.keep_alive.comment_text.clone();
let event = Event::default().comment(comment_str);
self.reset();
Poll::Ready(event)
}
}

View file

@ -1,533 +0,0 @@
//! Server-Sent Events (SSE)
//!
//! # Example
//!
//! ```
//! use axum::{prelude::*, sse::{sse, Event, KeepAlive}};
//! use tokio_stream::StreamExt as _;
//! use futures::stream::{self, Stream};
//! use std::{
//! time::Duration,
//! convert::Infallible,
//! };
//!
//! let app = route("/sse", sse(make_stream).keep_alive(KeepAlive::default()));
//!
//! async fn make_stream(
//! ) -> Result<impl Stream<Item = Result<Event, Infallible>>, Infallible> {
//! // A `Stream` that repeats an event every second
//! let stream = stream::repeat_with(|| Event::default().data("hi!"))
//! .map(Ok)
//! .throttle(Duration::from_secs(1));
//!
//! Ok(stream)
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
//!
//! SSE handlers can also use extractors:
//!
//! ```
//! use axum::{prelude::*, sse::{sse, Event}, extract::{RequestParts, FromRequest}};
//! use tokio_stream::StreamExt as _;
//! use futures::stream::{self, Stream};
//! use std::{
//! time::Duration,
//! convert::Infallible,
//! };
//! use http::{HeaderMap, StatusCode};
//!
//! /// An extractor that authorizes requests.
//! struct RequireAuth;
//!
//! #[async_trait::async_trait]
//! impl<B> FromRequest<B> for RequireAuth
//! where
//! B: Send,
//! {
//! type Rejection = StatusCode;
//!
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
//! # unimplemented!()
//! // Put your auth logic here...
//! }
//! }
//!
//! let app = route("/sse", sse(make_stream));
//!
//! async fn make_stream(
//! // Run `RequireAuth` for each request before initiating the stream.
//! _auth: RequireAuth,
//! ) -> Result<impl Stream<Item = Result<Event, Infallible>>, Infallible> {
//! // ...
//! # Ok(futures::stream::pending())
//! }
//! # async {
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
use crate::{
body::{box_body, BoxBody},
extract::{FromRequest, RequestParts},
response::IntoResponse,
Error,
};
use async_trait::async_trait;
use futures_util::{
future::{TryFuture, TryFutureExt},
stream::{Stream, StreamExt, TryStream, TryStreamExt},
};
use http::{Request, Response};
use hyper::Body;
use pin_project_lite::pin_project;
use serde::Serialize;
use std::{
borrow::Cow,
convert::Infallible,
fmt::{self, Write},
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::time::Sleep;
use tower::{BoxError, Service};
/// Create a new [`Sse`] service that will call the closure to produce a stream
/// of [`Event`]s.
///
/// See the [module docs](crate::sse) for more details.
pub fn sse<H, B, T>(handler: H) -> Sse<H, B, T>
where
H: SseHandler<B, T>,
{
Sse {
handler,
keep_alive: None,
_request_body: PhantomData,
}
}
/// Trait for async functions that can be used to handle Server-sent event
/// requests.
///
/// You shouldn't need to depend on this trait directly. It is automatically
/// implemented to closures of the right types.
///
/// See the [module docs](crate::sse) for more details.
#[async_trait]
pub trait SseHandler<B, In>: Sized {
/// The stream of events produced by the handler.
type Stream: TryStream<Ok = Event> + Send + 'static;
/// The error handler might fail with.
type Error: IntoResponse;
// This seals the trait. We cannot use the regular "sealed super trait"
// approach due to coherence.
#[doc(hidden)]
type Sealed: crate::handler::sealed::HiddentTrait;
/// Call the handler with the given input parsed by extractors and produce
/// the stream of events.
async fn call(self, input: In) -> Result<Self::Stream, Self::Error>;
}
#[async_trait]
impl<F, Fut, S, B> SseHandler<B, ()> for F
where
F: FnOnce() -> Fut + Send,
Fut: TryFuture<Ok = S> + Send,
Fut::Error: IntoResponse,
S: TryStream<Ok = Event> + Send + 'static,
{
type Stream = S;
type Error = Fut::Error;
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, _: ()) -> Result<Self::Stream, Self::Error> {
self().into_future().await
}
}
macro_rules! impl_sse_handler {
() => {
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, S, B, $head, $($tail,)*> SseHandler<B, ($head, $($tail,)*)> for F
where
B: Send,
F: FnOnce($head, $($tail,)*) -> Fut + Send,
Fut: TryFuture<Ok = S> + Send,
Fut::Error: IntoResponse,
S: TryStream<Ok = Event> + Send + 'static,
$head: FromRequest<B> + Send + 'static,
$( $tail: FromRequest<B> + Send + 'static, )*
{
type Stream = S;
type Error = Fut::Error;
type Sealed = crate::handler::sealed::Hidden;
async fn call(self, ($head, $($tail,)*): ($head, $($tail,)*)) -> Result<Self::Stream, Self::Error> {
self($head, $($tail,)*).into_future().await
}
}
impl_sse_handler!($($tail,)*);
};
}
impl_sse_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
/// [`Service`] that handlers streams of Server-sent events.
///
/// See the [module docs](crate::sse) for more details.
pub struct Sse<H, B, T> {
handler: H,
keep_alive: Option<KeepAlive>,
_request_body: PhantomData<fn() -> (B, T)>,
}
impl<H, B, T> Sse<H, B, T> {
/// Configure the interval between keep-alive messages.
///
/// Defaults to no keep-alive messages.
pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
}
}
impl<H, B, T> fmt::Debug for Sse<H, B, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Sse")
.field("handler", &format_args!("{}", std::any::type_name::<H>()))
.field("keep_alive", &self.keep_alive)
.finish()
}
}
impl<H, B, T> Clone for Sse<H, B, T>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
keep_alive: self.keep_alive.clone(),
_request_body: PhantomData,
}
}
}
impl<ReqBody, H, T> Service<Request<ReqBody>> for Sse<H, ReqBody, T>
where
H: SseHandler<ReqBody, T> + Clone + Send + 'static,
T: FromRequest<ReqBody> + Send,
ReqBody: Send + 'static,
<H::Stream as TryStream>::Error: Into<BoxError>,
{
type Response = Response<BoxBody>;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let handler = self.handler.clone();
let keep_alive = self.keep_alive.clone();
ResponseFuture {
future: Box::pin(async move {
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
let stream = match handler.call(input).await {
Ok(stream) => stream,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
let stream = if let Some(keep_alive) = keep_alive {
KeepAliveStream {
event_stream: stream,
comment_text: keep_alive.comment_text,
max_interval: keep_alive.max_interval,
alive_timer: tokio::time::sleep(keep_alive.max_interval),
}
.left_stream()
} else {
stream.into_stream().right_stream()
};
let stream = stream
.map_ok(|event| event.to_string())
.map_err(Error::new)
.into_stream();
let body = box_body(Body::wrap_stream(stream));
let response = Response::builder()
.header(http::header::CONTENT_TYPE, "text/event-stream")
.header(http::header::CACHE_CONTROL, "no-cache")
.body(body)
.unwrap();
Ok(response)
}),
}
}
}
opaque_future! {
/// Response future for [`Sse`].
pub type ResponseFuture =
futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
}
/// Server-sent event
#[derive(Default, Debug)]
pub struct Event {
name: Option<String>,
id: Option<String>,
data: Option<DataType>,
event: Option<String>,
comment: Option<String>,
retry: Option<Duration>,
}
// Server-sent event data type
#[derive(Debug)]
enum DataType {
Text(String),
Json(String),
}
impl Event {
/// Set Server-sent event data
/// data field(s) ("data:<content>")
pub fn data<T>(mut self, data: T) -> Event
where
T: Into<String>,
{
self.data = Some(DataType::Text(data.into()));
self
}
/// Set Server-sent event data
/// data field(s) ("data:<content>")
pub fn json_data<T>(mut self, data: T) -> Result<Event, serde_json::Error>
where
T: Serialize,
{
self.data = Some(DataType::Json(serde_json::to_string(&data)?));
Ok(self)
}
/// Set Server-sent event comment
/// Comment field (":<comment-text>")
pub fn comment<T>(mut self, comment: T) -> Event
where
T: Into<String>,
{
self.comment = Some(comment.into());
self
}
/// Set Server-sent event event
/// Event name field ("event:<event-name>")
pub fn event<T>(mut self, event: T) -> Event
where
T: Into<String>,
{
self.event = Some(event.into());
self
}
/// Set Server-sent event retry
/// Retry timeout field ("retry:<timeout>")
pub fn retry(mut self, duration: Duration) -> Event {
self.retry = Some(duration);
self
}
/// Set Server-sent event id
/// Identifier field ("id:<identifier>")
pub fn id<T>(mut self, id: T) -> Event
where
T: Into<String>,
{
self.id = Some(id.into());
self
}
}
impl fmt::Display for Event {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(comment) = &self.comment {
":".fmt(f)?;
comment.fmt(f)?;
f.write_char('\n')?;
}
if let Some(event) = &self.event {
"event:".fmt(f)?;
event.fmt(f)?;
f.write_char('\n')?;
}
match &self.data {
Some(DataType::Text(data)) => {
for line in data.split('\n') {
"data:".fmt(f)?;
line.fmt(f)?;
f.write_char('\n')?;
}
}
Some(DataType::Json(data)) => {
"data:".fmt(f)?;
data.fmt(f)?;
f.write_char('\n')?;
}
None => {}
}
if let Some(id) = &self.id {
"id:".fmt(f)?;
id.fmt(f)?;
f.write_char('\n')?;
}
if let Some(duration) = &self.retry {
"retry:".fmt(f)?;
let secs = duration.as_secs();
let millis = duration.subsec_millis();
if secs > 0 {
// format seconds
secs.fmt(f)?;
// pad milliseconds
if millis < 10 {
f.write_str("00")?;
} else if millis < 100 {
f.write_char('0')?;
}
}
// format milliseconds
millis.fmt(f)?;
f.write_char('\n')?;
}
f.write_char('\n')?;
Ok(())
}
}
/// Configure the interval between keep-alive messages, the content
/// of each message, and the associated stream.
#[derive(Debug, Clone)]
pub struct KeepAlive {
comment_text: Cow<'static, str>,
max_interval: Duration,
}
impl KeepAlive {
/// Create a new `KeepAlive`.
pub fn new() -> Self {
Self {
comment_text: Cow::Borrowed(""),
max_interval: Duration::from_secs(15),
}
}
/// Customize the interval between keep-alive messages.
///
/// Default is 15 seconds.
pub fn interval(mut self, time: Duration) -> Self {
self.max_interval = time;
self
}
/// Customize the text of the keep-alive message.
///
/// Default is an empty comment.
pub fn text<I>(mut self, text: I) -> Self
where
I: Into<Cow<'static, str>>,
{
self.comment_text = text.into();
self
}
}
impl Default for KeepAlive {
fn default() -> Self {
Self::new()
}
}
pin_project! {
struct KeepAliveStream<S> {
#[pin]
event_stream: S,
comment_text: Cow<'static, str>,
max_interval: Duration,
#[pin]
alive_timer: Sleep,
}
}
impl<S> Stream for KeepAliveStream<S>
where
S: TryStream<Ok = Event>,
{
type Item = Result<Event, S::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
match this.event_stream.try_poll_next(cx) {
Poll::Pending => match Pin::new(&mut this.alive_timer).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(_) => {
// restart timer
this.alive_timer
.reset(tokio::time::Instant::now() + *this.max_interval);
let comment_str = this.comment_text.clone();
let event = Event::default().comment(comment_str);
Poll::Ready(Some(Ok(event)))
}
},
Poll::Ready(Some(Ok(event))) => {
// restart timer
this.alive_timer
.reset(tokio::time::Instant::now() + *this.max_interval);
Poll::Ready(Some(Ok(event)))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
}
}
}