From d88212c0156df8be3ff230fc8fc1f432cf5ff4f8 Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Sat, 31 Jul 2021 10:51:41 +0200 Subject: [PATCH] Implement `Sink` and `Stream` for `WebSocket` (#52) Among other things, this makes [`StreamExt::split`](https://docs.rs/futures/0.3.16/futures/stream/trait.StreamExt.html#method.split) accessible so one can read and write at the same time. --- CHANGELOG.md | 3 ++- src/ws/mod.rs | 49 +++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 51844cbe..67c48bb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -None. +- Implement `Stream` for `WebSocket`. +- Implement `Sink` for `WebSocket`. ## Breaking changes diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 67e912c2..c8d0501b 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -61,7 +61,10 @@ use crate::response::IntoResponse; use async_trait::async_trait; use bytes::Bytes; use future::ResponseFuture; -use futures_util::{sink::SinkExt, stream::StreamExt}; +use futures_util::{ + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, +}; use http::{ header::{self, HeaderName}, HeaderValue, Request, Response, StatusCode, @@ -69,6 +72,7 @@ use http::{ use http_body::Full; use hyper::upgrade::{OnUpgrade, Upgraded}; use sha1::{Digest, Sha1}; +use std::pin::Pin; use std::{ borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context, task::Poll, @@ -348,12 +352,9 @@ pub struct WebSocket { impl WebSocket { /// Receive another message. /// - /// Returns `None` is stream has closed. + /// Returns `None` if the stream stream has closed. pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> { - self.inner - .next() - .await - .map(|result| result.map_err(Into::into).map(|inner| Message { inner })) + self.next().await } /// Send a message. @@ -367,6 +368,42 @@ impl WebSocket { } } +impl Stream for WebSocket { + type Item = Result<Message, BoxError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.inner.poll_next_unpin(cx).map(|option_msg| { + option_msg.map(|result_msg| { + result_msg + .map_err(Into::into) + .map(|inner| Message { inner }) + }) + }) + } +} + +impl Sink<Message> for WebSocket { + type Error = BoxError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + Pin::new(&mut self.inner) + .start_send(item.inner) + .map_err(Into::into) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { + Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { + Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into) + } +} + /// A WebSocket message. #[derive(Eq, PartialEq, Clone)] pub struct Message {