diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b2bd993..b923d755 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 required for returning responses with bodies other than `hyper::Body` from handlers. See the docs for advice on how to implement `IntoResponse` ([#86](https://github.com/tokio-rs/axum/pull/86)) - Change WebSocket API to use an extractor ([#121](https://github.com/tokio-rs/axum/pull/121)) +- Make WebSocket Message an enum ([#116](https://github.com/tokio-rs/axum/pull/116)) - Add `RoutingDsl::or` for combining routes. ([#108](https://github.com/tokio-rs/axum/pull/108)) - Ensure a `HandleError` service created from `axum::ServiceExt::handle_error` _does not_ implement `RoutingDsl` as that could lead to confusing routing diff --git a/examples/chat.rs b/examples/chat.rs index 025a8ad9..6bc34511 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -59,17 +59,19 @@ async fn websocket(stream: WebSocket, state: Arc) { let mut username = String::new(); // Loop until a text message is found. - while let Some(Ok(msg)) = receiver.next().await { - if let Some(name) = msg.to_str() { + while let Some(Ok(message)) = receiver.next().await { + if let Message::Text(name) = message { // If username that is sent by client is not taken, fill username string. - check_username(&state, &mut username, name); + check_username(&state, &mut username, &name); // If not empty we want to quit the loop else we want to quit function. if !username.is_empty() { break; } else { // Only send our client that username is taken. - let _ = sender.send(Message::text("Username already taken.")).await; + let _ = sender + .send(Message::Text(String::from("Username already taken."))) + .await; return; } @@ -87,7 +89,7 @@ async fn websocket(stream: WebSocket, state: Arc) { let mut send_task = tokio::spawn(async move { while let Ok(msg) = rx.recv().await { // In any websocket error, break loop. - if sender.send(Message::text(msg)).await.is_err() { + if sender.send(Message::Text(msg)).await.is_err() { break; } } @@ -99,11 +101,9 @@ async fn websocket(stream: WebSocket, state: Arc) { // This task will receive messages from client and send them to broadcast subscribers. let mut recv_task = tokio::spawn(async move { - while let Some(Ok(msg)) = receiver.next().await { - if let Some(text) = msg.to_str() { - // Add username before message. - let _ = tx.send(format!("{}: {}", name, text)); - } + while let Some(Ok(Message::Text(text))) = receiver.next().await { + // Add username before message. + let _ = tx.send(format!("{}: {}", name, text)); } }); diff --git a/examples/websocket.rs b/examples/websocket.rs index 563cdc53..c5fd9587 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -84,7 +84,11 @@ async fn handle_socket(mut socket: WebSocket) { } loop { - if socket.send(Message::text("Hi!")).await.is_err() { + if socket + .send(Message::Text(String::from("Hi!"))) + .await + .is_err() + { println!("client disconnected"); return; } diff --git a/src/extract/ws.rs b/src/extract/ws.rs index ed9e63fa..87e22adc 100644 --- a/src/extract/ws.rs +++ b/src/extract/ws.rs @@ -53,13 +53,15 @@ use hyper::upgrade::{OnUpgrade, Upgraded}; use sha1::{Digest, Sha1}; use std::{ borrow::Cow, - fmt, future::Future, pin::Pin, task::{Context, Poll}, }; use tokio_tungstenite::{ - tungstenite::protocol::{self, WebSocketConfig}, + tungstenite::{ + self, + protocol::{self, WebSocketConfig}, + }, WebSocketStream, }; use tower::BoxError; @@ -336,7 +338,10 @@ impl WebSocket { /// Send a message. pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> { - self.inner.send(msg.inner).await.map_err(Into::into) + self.inner + .send(msg.into_tungstenite()) + .await + .map_err(Into::into) } /// Gracefully close this WebSocket. @@ -353,7 +358,7 @@ impl Stream for WebSocket { option_msg.map(|result_msg| { result_msg .map_err(Into::into) - .map(|inner| Message { inner }) + .map(Message::from_tungstenite) }) }) } @@ -368,7 +373,7 @@ impl Sink for WebSocket { fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { Pin::new(&mut self.inner) - .start_send(item.inner) + .start_send(item.into_tungstenite()) .map_err(Into::into) } @@ -381,142 +386,131 @@ impl Sink for WebSocket { } } +/// Status code used to indicate why an endpoint is closing the WebSocket connection. +pub type CloseCode = u16; + +/// A struct representing the close command. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CloseFrame<'t> { + /// The reason as a code. + pub code: CloseCode, + /// The reason as text string. + pub reason: Cow<'t, str>, +} + /// A WebSocket message. -#[derive(Eq, PartialEq, Clone)] -pub struct Message { - inner: protocol::Message, +// +// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license: +// Copyright (c) 2017 Alexey Galakhov +// Copyright (c) 2016 Jason Housley +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Message { + /// A text WebSocket message + Text(String), + /// A binary WebSocket message + Binary(Vec), + /// A ping message with the specified payload + /// + /// The payload here must have a length less than 125 bytes + Ping(Vec), + /// A pong message with the specified payload + /// + /// The payload here must have a length less than 125 bytes + Pong(Vec), + /// A close message with the optional close frame. + Close(Option>), } impl Message { - /// Construct a new Text `Message`. - pub fn text(s: S) -> Message - where - S: Into, - { - Message { - inner: protocol::Message::text(s), + fn into_tungstenite(self) -> tungstenite::Message { + // TODO: maybe some shorter way to do that? + match self { + Self::Text(text) => tungstenite::Message::Text(text), + Self::Binary(binary) => tungstenite::Message::Binary(binary), + Self::Ping(ping) => tungstenite::Message::Ping(ping), + Self::Pong(pong) => tungstenite::Message::Pong(pong), + Self::Close(Some(close)) => { + tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame { + code: tungstenite::protocol::frame::coding::CloseCode::from(close.code), + reason: close.reason, + })) + } + Self::Close(None) => tungstenite::Message::Close(None), } } - /// Construct a new Binary `Message`. - pub fn binary(v: V) -> Message - where - V: Into>, - { - Message { - inner: protocol::Message::binary(v), - } - } - - /// Construct a new Ping `Message`. - pub fn ping>>(v: V) -> Message { - Message { - inner: protocol::Message::Ping(v.into()), - } - } - - /// Construct a new Pong `Message`. - /// - /// Note that one rarely needs to manually construct a Pong message because - /// the underlying tungstenite socket automatically responds to the Ping - /// messages it receives. Manual construction might still be useful in some - /// cases like in tests or to send unidirectional heartbeats. - pub fn pong>>(v: V) -> Message { - Message { - inner: protocol::Message::Pong(v.into()), - } - } - - /// Construct the default Close `Message`. - pub fn close() -> Message { - Message { - inner: protocol::Message::Close(None), - } - } - - /// Construct a Close `Message` with a code and reason. - pub fn close_with(code: C, reason: R) -> Message - where - C: Into, - R: Into>, - { - Message { - inner: protocol::Message::Close(Some(protocol::frame::CloseFrame { - code: protocol::frame::coding::CloseCode::from(code.into()), - reason: reason.into(), + fn from_tungstenite(message: tungstenite::Message) -> Self { + // TODO: maybe some shorter way to do that? + match message { + tungstenite::Message::Text(text) => Self::Text(text), + tungstenite::Message::Binary(binary) => Self::Binary(binary), + tungstenite::Message::Ping(ping) => Self::Ping(ping), + tungstenite::Message::Pong(pong) => Self::Pong(pong), + tungstenite::Message::Close(Some(close)) => Self::Close(Some(CloseFrame { + code: close.code.into(), + reason: close.reason, })), + tungstenite::Message::Close(None) => Self::Close(None), } } - /// Returns true if this message is a Text message. - pub fn is_text(&self) -> bool { - self.inner.is_text() - } - - /// Returns true if this message is a Binary message. - pub fn is_binary(&self) -> bool { - self.inner.is_binary() - } - - /// Returns true if this message a is a Close message. - pub fn is_close(&self) -> bool { - self.inner.is_close() - } - - /// Returns true if this message is a Ping message. - pub fn is_ping(&self) -> bool { - self.inner.is_ping() - } - - /// Returns true if this message is a Pong message. - pub fn is_pong(&self) -> bool { - self.inner.is_pong() - } - - /// Try to get the close frame (close code and reason) - pub fn close_frame(&self) -> Option<(u16, &str)> { - if let protocol::Message::Close(Some(close_frame)) = &self.inner { - Some((close_frame.code.into(), close_frame.reason.as_ref())) - } else { - None + /// Consume the WebSocket and return it as binary data. + pub fn into_data(self) -> Vec { + match self { + Self::Text(string) => string.into_bytes(), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data, + Self::Close(None) => Vec::new(), + Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), } } - /// Try to get a reference to the string text, if this is a Text message. - pub fn to_str(&self) -> Option<&str> { - if let protocol::Message::Text(s) = &self.inner { - Some(s) - } else { - None + /// Attempt to consume the WebSocket message and convert it to a String. + pub fn into_text(self) -> Result { + match self { + Self::Text(string) => Ok(string), + Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => { + Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?) + } + Self::Close(None) => Ok(String::new()), + Self::Close(Some(frame)) => Ok(frame.reason.into_owned()), } } - /// Return the bytes of this message, if the message can contain data. - pub fn as_bytes(&self) -> &[u8] { - match self.inner { - protocol::Message::Text(ref s) => s.as_bytes(), - protocol::Message::Binary(ref v) => v, - protocol::Message::Ping(ref v) => v, - protocol::Message::Pong(ref v) => v, - protocol::Message::Close(_) => &[], + /// Attempt to get a &str from the WebSocket message, + /// this will try to convert binary data to utf8. + pub fn to_text(&self) -> Result<&str, BoxError> { + match *self { + Self::Text(ref string) => Ok(string), + Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => { + Ok(std::str::from_utf8(data)?) + } + Self::Close(None) => Ok(""), + Self::Close(Some(ref frame)) => Ok(&frame.reason), } } - - /// Destructure this message into binary data. - pub fn into_bytes(self) -> Vec { - self.inner.into_data() - } -} - -impl fmt::Debug for Message { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt(f) - } } impl From for Vec { fn from(msg: Message) -> Self { - msg.into_bytes() + msg.into_data() } }