From d88212c0156df8be3ff230fc8fc1f432cf5ff4f8 Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Sat, 31 Jul 2021 10:51:41 +0200
Subject: [PATCH] Implement `Sink` and `Stream` for `WebSocket` (#52)

Among other things, this makes [`StreamExt::split`](https://docs.rs/futures/0.3.16/futures/stream/trait.StreamExt.html#method.split) accessible so one can read and write at the same time.
---
 CHANGELOG.md  |  3 ++-
 src/ws/mod.rs | 49 +++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51844cbe..67c48bb8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 # Unreleased
 
-None.
+- Implement `Stream` for `WebSocket`.
+- Implement `Sink` for `WebSocket`.
 
 ## Breaking changes
 
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index 67e912c2..c8d0501b 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -61,7 +61,10 @@ use crate::response::IntoResponse;
 use async_trait::async_trait;
 use bytes::Bytes;
 use future::ResponseFuture;
-use futures_util::{sink::SinkExt, stream::StreamExt};
+use futures_util::{
+    sink::{Sink, SinkExt},
+    stream::{Stream, StreamExt},
+};
 use http::{
     header::{self, HeaderName},
     HeaderValue, Request, Response, StatusCode,
@@ -69,6 +72,7 @@ use http::{
 use http_body::Full;
 use hyper::upgrade::{OnUpgrade, Upgraded};
 use sha1::{Digest, Sha1};
+use std::pin::Pin;
 use std::{
     borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
     task::Poll,
@@ -348,12 +352,9 @@ pub struct WebSocket {
 impl WebSocket {
     /// Receive another message.
     ///
-    /// Returns `None` is stream has closed.
+    /// Returns `None` if the stream stream has closed.
     pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
-        self.inner
-            .next()
-            .await
-            .map(|result| result.map_err(Into::into).map(|inner| Message { inner }))
+        self.next().await
     }
 
     /// Send a message.
@@ -367,6 +368,42 @@ impl WebSocket {
     }
 }
 
+impl Stream for WebSocket {
+    type Item = Result<Message, BoxError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.inner.poll_next_unpin(cx).map(|option_msg| {
+            option_msg.map(|result_msg| {
+                result_msg
+                    .map_err(Into::into)
+                    .map(|inner| Message { inner })
+            })
+        })
+    }
+}
+
+impl Sink<Message> for WebSocket {
+    type Error = BoxError;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
+        Pin::new(&mut self.inner)
+            .start_send(item.inner)
+            .map_err(Into::into)
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
+    }
+}
+
 /// A WebSocket message.
 #[derive(Eq, PartialEq, Clone)]
 pub struct Message {