diff --git a/CHANGELOG.md b/CHANGELOG.md index cef5159a..d71c9531 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Improve documentation for routing ([#71](https://github.com/tokio-rs/axum/pull/71)) - Clarify required response body type when routing to `tower::Service`s ([#69](https://github.com/tokio-rs/axum/pull/69)) - Add `axum::body::box_body` to converting an `http_body::Body` to `axum::body::BoxBody` ([#69](https://github.com/tokio-rs/axum/pull/69)) +- Add `axum::sse` for Server-Sent Events ([#75](https://github.com/tokio-rs/axum/pull/75)) - Mention required dependencies in docs ([#77](https://github.com/tokio-rs/axum/pull/77)) - Fix WebSockets failing on Firefox ([#76](https://github.com/tokio-rs/axum/pull/76)) diff --git a/Cargo.toml b/Cargo.toml index 5ff8f99c..1d5f701b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ bytes = "1.0" futures-util = "0.3" http = "0.2" http-body = "0.4" -hyper = { version = "0.14", features = ["server", "tcp", "http1"] } +hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] } pin-project = "1.0" regex = "1.5" serde = "1.0" @@ -57,6 +57,7 @@ tracing = "0.1" tracing-subscriber = "0.2" uuid = { version = "0.8", features = ["serde", "v4"] } async-session = "3.0.0" +tokio-stream = "0.1.7" [dev-dependencies.tower] version = "0.4" diff --git a/examples/sse.rs b/examples/sse.rs new file mode 100644 index 00000000..760ed2f4 --- /dev/null +++ b/examples/sse.rs @@ -0,0 +1,50 @@ +use axum::{extract::TypedHeader, prelude::*, routing::nest, service::ServiceExt, sse::Event}; +use futures::stream::{self, Stream}; +use http::StatusCode; +use std::{convert::Infallible, net::SocketAddr, time::Duration}; +use tokio_stream::StreamExt as _; +use tower_http::{services::ServeDir, trace::TraceLayer}; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + // build our application with a route + let app = nest( + "/", + axum::service::get( + ServeDir::new("examples/sse") + .append_index_html_on_directories(true) + .handle_error(|error: std::io::Error| { + Ok::<_, std::convert::Infallible>(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled interal error: {}", error), + )) + }), + ), + ) + .route("/sse", axum::sse::sse(make_stream)) + .layer(TraceLayer::new_for_http()); + + // run it + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + hyper::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn make_stream( + // sse handlers can also use extractors + TypedHeader(user_agent): TypedHeader, +) -> Result>, Infallible> { + println!("`{}` connected", user_agent.as_str()); + + // A `Stream` that repeats an event every second + let stream = stream::repeat_with(|| Event::default().data("hi!")) + .map(Ok) + .throttle(Duration::from_secs(1)); + + Ok(stream) +} diff --git a/examples/sse/index.html b/examples/sse/index.html new file mode 100644 index 00000000..390bb86b --- /dev/null +++ b/examples/sse/index.html @@ -0,0 +1 @@ + diff --git a/examples/sse/script.js b/examples/sse/script.js new file mode 100644 index 00000000..287dae22 --- /dev/null +++ b/examples/sse/script.js @@ -0,0 +1,5 @@ +var eventSource = new EventSource('sse'); + +eventSource.onmessage = function(event) { + console.log('Message from server ', event.data); +} diff --git a/src/lib.rs b/src/lib.rs index 6e9bbc08..3c97e8cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -735,6 +735,7 @@ pub mod handler; pub mod response; pub mod routing; pub mod service; +pub mod sse; #[cfg(feature = "ws")] #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] diff --git a/src/sse.rs b/src/sse.rs new file mode 100644 index 00000000..b7ee46ce --- /dev/null +++ b/src/sse.rs @@ -0,0 +1,529 @@ +//! Server-Sent Events (SSE) +//! +//! # Example +//! +//! ``` +//! use axum::{prelude::*, sse::{sse, Event, KeepAlive}}; +//! use tokio_stream::StreamExt as _; +//! use futures::stream::{self, Stream}; +//! use std::{ +//! time::Duration, +//! convert::Infallible, +//! }; +//! +//! let app = route("/sse", sse(make_stream).keep_alive(KeepAlive::default())); +//! +//! async fn make_stream( +//! ) -> Result>, Infallible> { +//! // A `Stream` that repeats an event every second +//! let stream = stream::repeat_with(|| Event::default().data("hi!")) +//! .map(Ok) +//! .throttle(Duration::from_secs(1)); +//! +//! Ok(stream) +//! } +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! SSE handlers can also use extractors: +//! +//! ``` +//! use axum::{prelude::*, sse::{sse, Event}, extract::{RequestParts, FromRequest}}; +//! use tokio_stream::StreamExt as _; +//! use futures::stream::{self, Stream}; +//! use std::{ +//! time::Duration, +//! convert::Infallible, +//! }; +//! use http::{HeaderMap, StatusCode}; +//! +//! /// An extractor that authorizes requests. +//! struct RequireAuth; +//! +//! #[async_trait::async_trait] +//! impl FromRequest for RequireAuth +//! where +//! B: Send, +//! { +//! type Rejection = StatusCode; +//! +//! async fn from_request(req: &mut RequestParts) -> Result { +//! # unimplemented!() +//! // Put your auth logic here... +//! } +//! } +//! +//! let app = route("/sse", sse(make_stream)); +//! +//! async fn make_stream( +//! // Run `RequireAuth` for each request before initiating the stream. +//! _auth: RequireAuth, +//! ) -> Result>, Infallible> { +//! // ... +//! # Ok(futures::stream::pending()) +//! } +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` + +use crate::{ + body::{box_body, BoxBody, BoxStdError}, + extract::{FromRequest, RequestParts}, + response::IntoResponse, +}; +use async_trait::async_trait; +use futures_util::{ + future::{TryFuture, TryFutureExt}, + stream::{Stream, StreamExt, TryStream, TryStreamExt}, +}; +use http::{Request, Response}; +use hyper::Body; +use pin_project::pin_project; +use serde::Serialize; +use std::{ + borrow::Cow, + convert::Infallible, + fmt::{self, Write}, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower::{BoxError, Service}; + +/// Create a new [`Sse`] service that will call the closure to produce a stream +/// of [`Event`]s. +/// +/// See the [module docs](crate::sse) for more details. +pub fn sse(handler: H) -> Sse +where + H: SseHandler, +{ + Sse { + handler, + keep_alive: None, + _request_body: PhantomData, + } +} + +/// Trait for async functions that can be used to handle Server-sent event +/// requests. +/// +/// You shouldn't need to depend on this trait directly. It is automatically +/// implemented to closures of the right types. +/// +/// See the [module docs](crate::sse) for more details. +#[async_trait] +pub trait SseHandler: Sized { + /// The stream of events produced by the handler. + type Stream: TryStream + Send + 'static; + + /// The error handler might fail with. + type Error: IntoResponse; + + // This seals the trait. We cannot use the regular "sealed super trait" + // approach due to coherence. + #[doc(hidden)] + type Sealed: crate::handler::sealed::HiddentTrait; + + /// Call the handler with the given input parsed by extractors and produce + /// the stream of events. + async fn call(self, input: In) -> Result; +} + +#[async_trait] +impl SseHandler for F +where + F: FnOnce() -> Fut + Send, + Fut: TryFuture + Send, + Fut::Error: IntoResponse, + S: TryStream + Send + 'static, +{ + type Stream = S; + type Error = Fut::Error; + type Sealed = crate::handler::sealed::Hidden; + + async fn call(self, _: ()) -> Result { + self().into_future().await + } +} + +macro_rules! impl_sse_handler { + () => { + }; + + ( $head:ident, $($tail:ident),* $(,)? ) => { + #[async_trait] + #[allow(non_snake_case)] + impl SseHandler for F + where + B: Send, + F: FnOnce($head, $($tail,)*) -> Fut + Send, + Fut: TryFuture + Send, + Fut::Error: IntoResponse, + S: TryStream + Send + 'static, + $head: FromRequest + Send + 'static, + $( $tail: FromRequest + Send + 'static, )* + { + type Stream = S; + type Error = Fut::Error; + type Sealed = crate::handler::sealed::Hidden; + + async fn call(self, ($head, $($tail,)*): ($head, $($tail,)*)) -> Result { + self($head, $($tail,)*).into_future().await + } + } + + impl_sse_handler!($($tail,)*); + }; +} + +impl_sse_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); + +/// [`Service`] that handlers streams of Server-sent events. +/// +/// See the [module docs](crate::sse) for more details. +pub struct Sse { + handler: H, + keep_alive: Option, + _request_body: PhantomData (B, T)>, +} + +impl Sse { + /// Configure the interval between keep-alive messages. + /// + /// Defaults to no keep-alive messages. + pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self { + self.keep_alive = Some(keep_alive); + self + } +} + +impl fmt::Debug for Sse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Sse") + .field("handler", &format_args!("{}", std::any::type_name::())) + .field("keep_alive", &self.keep_alive) + .finish() + } +} + +impl Clone for Sse +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + keep_alive: self.keep_alive.clone(), + _request_body: PhantomData, + } + } +} + +impl Service> for Sse +where + H: SseHandler + Clone + Send + 'static, + T: FromRequest + Send, + ReqBody: Send + 'static, + ::Error: Into, +{ + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let handler = self.handler.clone(); + let keep_alive = self.keep_alive.clone(); + + ResponseFuture(Box::pin(async move { + let mut req = RequestParts::new(req); + let input = match T::from_request(&mut req).await { + Ok(input) => input, + Err(err) => { + return Ok(err.into_response().map(box_body)); + } + }; + + let stream = match handler.call(input).await { + Ok(stream) => stream, + Err(err) => { + return Ok(err.into_response().map(box_body)); + } + }; + + let stream = if let Some(keep_alive) = keep_alive { + KeepAliveStream { + event_stream: stream, + comment_text: keep_alive.comment_text, + max_interval: keep_alive.max_interval, + alive_timer: tokio::time::sleep(keep_alive.max_interval), + } + .left_stream() + } else { + stream.into_stream().right_stream() + }; + + let stream = stream + .map_ok(|event| event.to_string()) + .map_err(|err| BoxStdError(err.into())) + .into_stream(); + + let body = box_body(Body::wrap_stream(stream)); + + let response = Response::builder() + .header(http::header::CONTENT_TYPE, "text/event-stream") + .header(http::header::CACHE_CONTROL, "no-cache") + .body(body) + .unwrap(); + + Ok(response) + })) + } +} + +opaque_future! { + /// Response future for [`Sse`]. + pub type ResponseFuture = + futures_util::future::BoxFuture<'static, Result, Infallible>>; +} + +/// Server-sent event +#[derive(Default, Debug)] +pub struct Event { + name: Option, + id: Option, + data: Option, + event: Option, + comment: Option, + retry: Option, +} + +// Server-sent event data type +#[derive(Debug)] +enum DataType { + Text(String), + Json(String), +} + +impl Event { + /// Set Server-sent event data + /// data field(s) ("data:") + pub fn data(mut self, data: T) -> Event + where + T: Into, + { + self.data = Some(DataType::Text(data.into())); + self + } + + /// Set Server-sent event data + /// data field(s) ("data:") + pub fn json_data(mut self, data: T) -> Result + where + T: Serialize, + { + self.data = Some(DataType::Json(serde_json::to_string(&data)?)); + Ok(self) + } + + /// Set Server-sent event comment + /// Comment field (":") + pub fn comment(mut self, comment: T) -> Event + where + T: Into, + { + self.comment = Some(comment.into()); + self + } + + /// Set Server-sent event event + /// Event name field ("event:") + pub fn event(mut self, event: T) -> Event + where + T: Into, + { + self.event = Some(event.into()); + self + } + + /// Set Server-sent event retry + /// Retry timeout field ("retry:") + pub fn retry(mut self, duration: Duration) -> Event { + self.retry = Some(duration); + self + } + + /// Set Server-sent event id + /// Identifier field ("id:") + pub fn id(mut self, id: T) -> Event + where + T: Into, + { + self.id = Some(id.into()); + self + } +} + +impl fmt::Display for Event { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(comment) = &self.comment { + ":".fmt(f)?; + comment.fmt(f)?; + f.write_char('\n')?; + } + + if let Some(event) = &self.event { + "event:".fmt(f)?; + event.fmt(f)?; + f.write_char('\n')?; + } + + match &self.data { + Some(DataType::Text(data)) => { + for line in data.split('\n') { + "data:".fmt(f)?; + line.fmt(f)?; + f.write_char('\n')?; + } + } + Some(DataType::Json(data)) => { + "data:".fmt(f)?; + data.fmt(f)?; + f.write_char('\n')?; + } + None => {} + } + + if let Some(id) = &self.id { + "id:".fmt(f)?; + id.fmt(f)?; + f.write_char('\n')?; + } + + if let Some(duration) = &self.retry { + "retry:".fmt(f)?; + + let secs = duration.as_secs(); + let millis = duration.subsec_millis(); + + if secs > 0 { + // format seconds + secs.fmt(f)?; + + // pad milliseconds + if millis < 10 { + f.write_str("00")?; + } else if millis < 100 { + f.write_char('0')?; + } + } + + // format milliseconds + millis.fmt(f)?; + + f.write_char('\n')?; + } + + f.write_char('\n')?; + + Ok(()) + } +} + +/// Configure the interval between keep-alive messages, the content +/// of each message, and the associated stream. +#[derive(Debug, Clone)] +pub struct KeepAlive { + comment_text: Cow<'static, str>, + max_interval: Duration, +} + +impl KeepAlive { + /// Create a new `KeepAlive`. + pub fn new() -> Self { + Self { + comment_text: Cow::Borrowed(""), + max_interval: Duration::from_secs(15), + } + } + + /// Customize the interval between keep-alive messages. + /// + /// Default is 15 seconds. + pub fn interval(mut self, time: Duration) -> Self { + self.max_interval = time; + self + } + + /// Customize the text of the keep-alive message. + /// + /// Default is an empty comment. + pub fn text(mut self, text: I) -> Self + where + I: Into>, + { + self.comment_text = text.into(); + self + } +} + +impl Default for KeepAlive { + fn default() -> Self { + Self::new() + } +} + +#[pin_project] +struct KeepAliveStream { + #[pin] + event_stream: S, + comment_text: Cow<'static, str>, + max_interval: Duration, + #[pin] + alive_timer: Sleep, +} + +impl Stream for KeepAliveStream +where + S: TryStream, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + match this.event_stream.try_poll_next(cx) { + Poll::Pending => match Pin::new(&mut this.alive_timer).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(_) => { + // restart timer + this.alive_timer + .reset(tokio::time::Instant::now() + *this.max_interval); + + let comment_str = this.comment_text.clone(); + let event = Event::default().comment(comment_str); + Poll::Ready(Some(Ok(event))) + } + }, + Poll::Ready(Some(Ok(event))) => { + // restart timer + this.alive_timer + .reset(tokio::time::Instant::now() + *this.max_interval); + + Poll::Ready(Some(Ok(event))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))), + } + } +}