mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-28 07:20:12 +01:00
Upgrade tokio-tungstenite to 0.26 (#3078)
This commit is contained in:
parent
5cdd8a4f18
commit
96e071c8fb
9 changed files with 194 additions and 58 deletions
|
@ -76,7 +76,7 @@ serde_path_to_error = { version = "0.1.8", optional = true }
|
||||||
serde_urlencoded = { version = "0.7", optional = true }
|
serde_urlencoded = { version = "0.7", optional = true }
|
||||||
sha1 = { version = "0.10", optional = true }
|
sha1 = { version = "0.10", optional = true }
|
||||||
tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true }
|
tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true }
|
||||||
tokio-tungstenite = { version = "0.24.0", optional = true }
|
tokio-tungstenite = { version = "0.26.0", optional = true }
|
||||||
tracing = { version = "0.1", default-features = false, optional = true }
|
tracing = { version = "0.1", default-features = false, optional = true }
|
||||||
|
|
||||||
[dependencies.tower-http]
|
[dependencies.tower-http]
|
||||||
|
@ -127,7 +127,7 @@ serde_json = { version = "1.0", features = ["raw_value"] }
|
||||||
time = { version = "0.3", features = ["serde-human-readable"] }
|
time = { version = "0.3", features = ["serde-human-readable"] }
|
||||||
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
|
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
tokio-tungstenite = "0.24.0"
|
tokio-tungstenite = "0.26.0"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["json"] }
|
tracing-subscriber = { version = "0.3", features = ["json"] }
|
||||||
uuid = { version = "1.0", features = ["serde", "v4"] }
|
uuid = { version = "1.0", features = ["serde", "v4"] }
|
||||||
|
|
|
@ -553,16 +553,131 @@ impl Sink<Message> for WebSocket {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// UTF-8 wrapper for [Bytes].
|
||||||
|
///
|
||||||
|
/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
|
pub struct Utf8Bytes(ts::Utf8Bytes);
|
||||||
|
|
||||||
|
impl Utf8Bytes {
|
||||||
|
/// Creates from a static str.
|
||||||
|
#[inline]
|
||||||
|
pub const fn from_static(str: &'static str) -> Self {
|
||||||
|
Self(ts::Utf8Bytes::from_static(str))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns as a string slice.
|
||||||
|
#[inline]
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
self.0.as_str()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_tungstenite(self) -> ts::Utf8Bytes {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for Utf8Bytes {
|
||||||
|
type Target = str;
|
||||||
|
|
||||||
|
/// ```
|
||||||
|
/// /// Example fn that takes a str slice
|
||||||
|
/// fn a(s: &str) {}
|
||||||
|
///
|
||||||
|
/// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
|
||||||
|
///
|
||||||
|
/// // auto-deref as arg
|
||||||
|
/// a(&data);
|
||||||
|
///
|
||||||
|
/// // deref to str methods
|
||||||
|
/// assert_eq!(data.len(), 6);
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.as_str()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Utf8Bytes {
|
||||||
|
#[inline]
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str(self.as_str())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<Bytes> for Utf8Bytes {
|
||||||
|
type Error = std::str::Utf8Error;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
|
||||||
|
Ok(Self(bytes.try_into()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<Vec<u8>> for Utf8Bytes {
|
||||||
|
type Error = std::str::Utf8Error;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
|
||||||
|
Ok(Self(v.try_into()?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for Utf8Bytes {
|
||||||
|
#[inline]
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Utf8Bytes {
|
||||||
|
#[inline]
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self(s.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&String> for Utf8Bytes {
|
||||||
|
#[inline]
|
||||||
|
fn from(s: &String) -> Self {
|
||||||
|
Self(s.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Utf8Bytes> for Bytes {
|
||||||
|
#[inline]
|
||||||
|
fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
|
||||||
|
bytes.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> PartialEq<T> for Utf8Bytes
|
||||||
|
where
|
||||||
|
for<'a> &'a str: PartialEq<T>,
|
||||||
|
{
|
||||||
|
/// ```
|
||||||
|
/// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
|
||||||
|
/// assert_eq!(payload, "foo123");
|
||||||
|
/// assert_eq!(payload, "foo123".to_string());
|
||||||
|
/// assert_eq!(payload, &"foo123".to_string());
|
||||||
|
/// assert_eq!(payload, std::borrow::Cow::from("foo123"));
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
fn eq(&self, other: &T) -> bool {
|
||||||
|
self.as_str() == *other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
|
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
|
||||||
pub type CloseCode = u16;
|
pub type CloseCode = u16;
|
||||||
|
|
||||||
/// A struct representing the close command.
|
/// A struct representing the close command.
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||||
pub struct CloseFrame<'t> {
|
pub struct CloseFrame {
|
||||||
/// The reason as a code.
|
/// The reason as a code.
|
||||||
pub code: CloseCode,
|
pub code: CloseCode,
|
||||||
/// The reason as text string.
|
/// The reason as text string.
|
||||||
pub reason: Cow<'t, str>,
|
pub reason: Utf8Bytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A WebSocket message.
|
/// A WebSocket message.
|
||||||
|
@ -591,16 +706,16 @@ pub struct CloseFrame<'t> {
|
||||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||||
pub enum Message {
|
pub enum Message {
|
||||||
/// A text WebSocket message
|
/// A text WebSocket message
|
||||||
Text(String),
|
Text(Utf8Bytes),
|
||||||
/// A binary WebSocket message
|
/// A binary WebSocket message
|
||||||
Binary(Vec<u8>),
|
Binary(Bytes),
|
||||||
/// A ping message with the specified payload
|
/// A ping message with the specified payload
|
||||||
///
|
///
|
||||||
/// The payload here must have a length less than 125 bytes.
|
/// The payload here must have a length less than 125 bytes.
|
||||||
///
|
///
|
||||||
/// Ping messages will be automatically responded to by the server, so you do not have to worry
|
/// Ping messages will be automatically responded to by the server, so you do not have to worry
|
||||||
/// about dealing with them yourself.
|
/// about dealing with them yourself.
|
||||||
Ping(Vec<u8>),
|
Ping(Bytes),
|
||||||
/// A pong message with the specified payload
|
/// A pong message with the specified payload
|
||||||
///
|
///
|
||||||
/// The payload here must have a length less than 125 bytes.
|
/// The payload here must have a length less than 125 bytes.
|
||||||
|
@ -608,7 +723,7 @@ pub enum Message {
|
||||||
/// Pong messages will be automatically sent to the client if a ping message is received, so
|
/// Pong messages will be automatically sent to the client if a ping message is received, so
|
||||||
/// you do not have to worry about constructing them yourself unless you want to implement a
|
/// you do not have to worry about constructing them yourself unless you want to implement a
|
||||||
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
|
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
|
||||||
Pong(Vec<u8>),
|
Pong(Bytes),
|
||||||
/// A close message with the optional close frame.
|
/// A close message with the optional close frame.
|
||||||
///
|
///
|
||||||
/// You may "uncleanly" close a WebSocket connection at any time
|
/// You may "uncleanly" close a WebSocket connection at any time
|
||||||
|
@ -628,19 +743,19 @@ pub enum Message {
|
||||||
/// Since no further messages will be received,
|
/// Since no further messages will be received,
|
||||||
/// you may either do nothing
|
/// you may either do nothing
|
||||||
/// or explicitly drop the connection.
|
/// or explicitly drop the connection.
|
||||||
Close(Option<CloseFrame<'static>>),
|
Close(Option<CloseFrame>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
fn into_tungstenite(self) -> ts::Message {
|
fn into_tungstenite(self) -> ts::Message {
|
||||||
match self {
|
match self {
|
||||||
Self::Text(text) => ts::Message::Text(text),
|
Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
|
||||||
Self::Binary(binary) => ts::Message::Binary(binary),
|
Self::Binary(binary) => ts::Message::Binary(binary),
|
||||||
Self::Ping(ping) => ts::Message::Ping(ping),
|
Self::Ping(ping) => ts::Message::Ping(ping),
|
||||||
Self::Pong(pong) => ts::Message::Pong(pong),
|
Self::Pong(pong) => ts::Message::Pong(pong),
|
||||||
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
|
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
|
||||||
code: ts::protocol::frame::coding::CloseCode::from(close.code),
|
code: ts::protocol::frame::coding::CloseCode::from(close.code),
|
||||||
reason: close.reason,
|
reason: close.reason.into_tungstenite(),
|
||||||
})),
|
})),
|
||||||
Self::Close(None) => ts::Message::Close(None),
|
Self::Close(None) => ts::Message::Close(None),
|
||||||
}
|
}
|
||||||
|
@ -648,13 +763,13 @@ impl Message {
|
||||||
|
|
||||||
fn from_tungstenite(message: ts::Message) -> Option<Self> {
|
fn from_tungstenite(message: ts::Message) -> Option<Self> {
|
||||||
match message {
|
match message {
|
||||||
ts::Message::Text(text) => Some(Self::Text(text)),
|
ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
|
||||||
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
|
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
|
||||||
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
|
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
|
||||||
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
|
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
|
||||||
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
|
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
|
||||||
code: close.code.into(),
|
code: close.code.into(),
|
||||||
reason: close.reason,
|
reason: Utf8Bytes(close.reason),
|
||||||
}))),
|
}))),
|
||||||
ts::Message::Close(None) => Some(Self::Close(None)),
|
ts::Message::Close(None) => Some(Self::Close(None)),
|
||||||
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
|
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
|
||||||
|
@ -664,24 +779,24 @@ impl Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Consume the WebSocket and return it as binary data.
|
/// Consume the WebSocket and return it as binary data.
|
||||||
pub fn into_data(self) -> Vec<u8> {
|
pub fn into_data(self) -> Bytes {
|
||||||
match self {
|
match self {
|
||||||
Self::Text(string) => string.into_bytes(),
|
Self::Text(string) => Bytes::from(string),
|
||||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
|
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
|
||||||
Self::Close(None) => Vec::new(),
|
Self::Close(None) => Bytes::new(),
|
||||||
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
|
Self::Close(Some(frame)) => Bytes::from(frame.reason),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attempt to consume the WebSocket message and convert it to a String.
|
/// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
|
||||||
pub fn into_text(self) -> Result<String, Error> {
|
pub fn into_text(self) -> Result<Utf8Bytes, Error> {
|
||||||
match self {
|
match self {
|
||||||
Self::Text(string) => Ok(string),
|
Self::Text(string) => Ok(string),
|
||||||
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
|
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
|
||||||
.map_err(|err| err.utf8_error())
|
Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
|
||||||
.map_err(Error::new)?),
|
}
|
||||||
Self::Close(None) => Ok(String::new()),
|
Self::Close(None) => Ok(Utf8Bytes::default()),
|
||||||
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
|
Self::Close(Some(frame)) => Ok(frame.reason),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -689,7 +804,7 @@ impl Message {
|
||||||
/// this will try to convert binary data to utf8.
|
/// this will try to convert binary data to utf8.
|
||||||
pub fn to_text(&self) -> Result<&str, Error> {
|
pub fn to_text(&self) -> Result<&str, Error> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Text(ref string) => Ok(string),
|
Self::Text(ref string) => Ok(string.as_str()),
|
||||||
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
|
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
|
||||||
Ok(std::str::from_utf8(data).map_err(Error::new)?)
|
Ok(std::str::from_utf8(data).map_err(Error::new)?)
|
||||||
}
|
}
|
||||||
|
@ -697,11 +812,27 @@ impl Message {
|
||||||
Self::Close(Some(ref frame)) => Ok(&frame.reason),
|
Self::Close(Some(ref frame)) => Ok(&frame.reason),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new text WebSocket message from a stringable.
|
||||||
|
pub fn text<S>(string: S) -> Message
|
||||||
|
where
|
||||||
|
S: Into<Utf8Bytes>,
|
||||||
|
{
|
||||||
|
Message::Text(string.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new binary WebSocket message by converting to `Bytes`.
|
||||||
|
pub fn binary<B>(bin: B) -> Message
|
||||||
|
where
|
||||||
|
B: Into<Bytes>,
|
||||||
|
{
|
||||||
|
Message::Binary(bin.into())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<String> for Message {
|
impl From<String> for Message {
|
||||||
fn from(string: String) -> Self {
|
fn from(string: String) -> Self {
|
||||||
Message::Text(string)
|
Message::Text(string.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -713,19 +844,19 @@ impl<'s> From<&'s str> for Message {
|
||||||
|
|
||||||
impl<'b> From<&'b [u8]> for Message {
|
impl<'b> From<&'b [u8]> for Message {
|
||||||
fn from(data: &'b [u8]) -> Self {
|
fn from(data: &'b [u8]) -> Self {
|
||||||
Message::Binary(data.into())
|
Message::Binary(Bytes::copy_from_slice(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Vec<u8>> for Message {
|
impl From<Vec<u8>> for Message {
|
||||||
fn from(data: Vec<u8>) -> Self {
|
fn from(data: Vec<u8>) -> Self {
|
||||||
Message::Binary(data)
|
Message::Binary(data.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Message> for Vec<u8> {
|
impl From<Message> for Vec<u8> {
|
||||||
fn from(msg: Message) -> Self {
|
fn from(msg: Message) -> Self {
|
||||||
msg.into_data()
|
msg.into_data().to_vec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1026,19 +1157,19 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
|
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
|
||||||
let input = tungstenite::Message::Text("foobar".to_owned());
|
let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
|
||||||
socket.send(input.clone()).await.unwrap();
|
socket.send(input.clone()).await.unwrap();
|
||||||
let output = socket.next().await.unwrap().unwrap();
|
let output = socket.next().await.unwrap().unwrap();
|
||||||
assert_eq!(input, output);
|
assert_eq!(input, output);
|
||||||
|
|
||||||
socket
|
socket
|
||||||
.send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
|
.send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let output = socket.next().await.unwrap().unwrap();
|
let output = socket.next().await.unwrap().unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
output,
|
output,
|
||||||
tungstenite::Message::Pong("ping".to_owned().into_bytes())
|
tungstenite::Message::Pong(Bytes::from_static(b"ping"))
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{
|
extract::{
|
||||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
|
||||||
State,
|
State,
|
||||||
},
|
},
|
||||||
response::{Html, IntoResponse},
|
response::{Html, IntoResponse},
|
||||||
|
@ -79,7 +79,7 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
||||||
while let Some(Ok(message)) = receiver.next().await {
|
while let Some(Ok(message)) = receiver.next().await {
|
||||||
if let Message::Text(name) = message {
|
if let Message::Text(name) = message {
|
||||||
// If username that is sent by client is not taken, fill username string.
|
// If username that is sent by client is not taken, fill username string.
|
||||||
check_username(&state, &mut username, &name);
|
check_username(&state, &mut username, name.as_str());
|
||||||
|
|
||||||
// If not empty we want to quit the loop else we want to quit function.
|
// If not empty we want to quit the loop else we want to quit function.
|
||||||
if !username.is_empty() {
|
if !username.is_empty() {
|
||||||
|
@ -87,7 +87,9 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
||||||
} else {
|
} else {
|
||||||
// Only send our client that username is taken.
|
// Only send our client that username is taken.
|
||||||
let _ = sender
|
let _ = sender
|
||||||
.send(Message::Text(String::from("Username already taken.")))
|
.send(Message::Text(Utf8Bytes::from_static(
|
||||||
|
"Username already taken.",
|
||||||
|
)))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
|
@ -109,7 +111,7 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
||||||
let mut send_task = tokio::spawn(async move {
|
let mut send_task = tokio::spawn(async move {
|
||||||
while let Ok(msg) = rx.recv().await {
|
while let Ok(msg) = rx.recv().await {
|
||||||
// In any websocket error, break loop.
|
// In any websocket error, break loop.
|
||||||
if sender.send(Message::Text(msg)).await.is_err() {
|
if sender.send(Message::text(msg)).await.is_err() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,4 +8,4 @@ publish = false
|
||||||
axum = { path = "../../axum", features = ["ws"] }
|
axum = { path = "../../axum", features = ["ws"] }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
tokio-tungstenite = "0.24"
|
tokio-tungstenite = "0.26"
|
||||||
|
|
|
@ -48,7 +48,7 @@ async fn integration_testable_handle_socket(mut socket: WebSocket) {
|
||||||
while let Some(Ok(msg)) = socket.recv().await {
|
while let Some(Ok(msg)) = socket.recv().await {
|
||||||
if let Message::Text(msg) = msg {
|
if let Message::Text(msg) = msg {
|
||||||
if socket
|
if socket
|
||||||
.send(Message::Text(format!("You said: {msg}")))
|
.send(Message::Text(format!("You said: {msg}").into()))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
|
@ -79,7 +79,7 @@ where
|
||||||
while let Some(Ok(msg)) = read.next().await {
|
while let Some(Ok(msg)) = read.next().await {
|
||||||
if let Message::Text(msg) = msg {
|
if let Message::Text(msg) = msg {
|
||||||
if write
|
if write
|
||||||
.send(Message::Text(format!("You said: {msg}")))
|
.send(Message::Text(format!("You said: {msg}").into()))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
|
@ -123,7 +123,7 @@ mod tests {
|
||||||
other => panic!("expected a text message but got {other:?}"),
|
other => panic!("expected a text message but got {other:?}"),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(msg, "You said: foo");
|
assert_eq!(msg.as_str(), "You said: foo");
|
||||||
}
|
}
|
||||||
|
|
||||||
// We can unit test the other handler by creating channels to read and write from.
|
// We can unit test the other handler by creating channels to read and write from.
|
||||||
|
@ -136,16 +136,13 @@ mod tests {
|
||||||
|
|
||||||
tokio::spawn(unit_testable_handle_socket(socket_write, socket_read));
|
tokio::spawn(unit_testable_handle_socket(socket_write, socket_read));
|
||||||
|
|
||||||
test_tx
|
test_tx.send(Ok(Message::Text("foo".into()))).await.unwrap();
|
||||||
.send(Ok(Message::Text("foo".to_owned())))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let msg = match test_rx.next().await.unwrap() {
|
let msg = match test_rx.next().await.unwrap() {
|
||||||
Message::Text(msg) => msg,
|
Message::Text(msg) => msg,
|
||||||
other => panic!("expected a text message but got {other:?}"),
|
other => panic!("expected a text message but got {other:?}"),
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(msg, "You said: foo");
|
assert_eq!(msg.as_str(), "You said: foo");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ async fn ws_handler(
|
||||||
res = ws.recv() => {
|
res = ws.recv() => {
|
||||||
match res {
|
match res {
|
||||||
Some(Ok(ws::Message::Text(s))) => {
|
Some(Ok(ws::Message::Text(s))) => {
|
||||||
let _ = sender.send(s);
|
let _ = sender.send(s.to_string());
|
||||||
}
|
}
|
||||||
Some(Ok(_)) => {}
|
Some(Ok(_)) => {}
|
||||||
Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"),
|
Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"),
|
||||||
|
@ -85,7 +85,7 @@ async fn ws_handler(
|
||||||
// Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe.
|
// Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe.
|
||||||
res = receiver.recv() => {
|
res = receiver.recv() => {
|
||||||
match res {
|
match res {
|
||||||
Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg)).await {
|
Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg.into())).await {
|
||||||
tracing::debug!("client disconnected abruptly: {e}");
|
tracing::debug!("client disconnected abruptly: {e}");
|
||||||
}
|
}
|
||||||
Err(_) => continue,
|
Err(_) => continue,
|
||||||
|
|
|
@ -11,7 +11,7 @@ futures = "0.3"
|
||||||
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
|
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
|
||||||
headers = "0.4"
|
headers = "0.4"
|
||||||
tokio = { version = "1.0", features = ["full"] }
|
tokio = { version = "1.0", features = ["full"] }
|
||||||
tokio-tungstenite = "0.24.0"
|
tokio-tungstenite = "0.26.0"
|
||||||
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
|
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
|
|
||||||
use futures_util::stream::FuturesUnordered;
|
use futures_util::stream::FuturesUnordered;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use tokio_tungstenite::tungstenite::Utf8Bytes;
|
||||||
|
|
||||||
// we will use tungstenite for websocket client impl (same library as what axum is using)
|
// we will use tungstenite for websocket client impl (same library as what axum is using)
|
||||||
use tokio_tungstenite::{
|
use tokio_tungstenite::{
|
||||||
|
@ -65,7 +65,9 @@ async fn spawn_client(who: usize) {
|
||||||
|
|
||||||
//we can ping the server for start
|
//we can ping the server for start
|
||||||
sender
|
sender
|
||||||
.send(Message::Ping("Hello, Server!".into()))
|
.send(Message::Ping(axum::body::Bytes::from_static(
|
||||||
|
b"Hello, Server!",
|
||||||
|
)))
|
||||||
.await
|
.await
|
||||||
.expect("Can not send!");
|
.expect("Can not send!");
|
||||||
|
|
||||||
|
@ -74,7 +76,7 @@ async fn spawn_client(who: usize) {
|
||||||
for i in 1..30 {
|
for i in 1..30 {
|
||||||
// In any websocket error, break loop.
|
// In any websocket error, break loop.
|
||||||
if sender
|
if sender
|
||||||
.send(Message::Text(format!("Message number {i}...")))
|
.send(Message::Text(format!("Message number {i}...").into()))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
|
@ -90,7 +92,7 @@ async fn spawn_client(who: usize) {
|
||||||
if let Err(e) = sender
|
if let Err(e) = sender
|
||||||
.send(Message::Close(Some(CloseFrame {
|
.send(Message::Close(Some(CloseFrame {
|
||||||
code: CloseCode::Normal,
|
code: CloseCode::Normal,
|
||||||
reason: Cow::from("Goodbye"),
|
reason: Utf8Bytes::from_static("Goodbye"),
|
||||||
})))
|
})))
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
|
@ -17,14 +17,14 @@
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::ws::{Message, WebSocket, WebSocketUpgrade},
|
body::Bytes,
|
||||||
|
extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::any,
|
routing::any,
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use axum_extra::TypedHeader;
|
use axum_extra::TypedHeader;
|
||||||
|
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
use std::{net::SocketAddr, path::PathBuf};
|
use std::{net::SocketAddr, path::PathBuf};
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
|
@ -101,7 +101,11 @@ async fn ws_handler(
|
||||||
/// Actual websocket statemachine (one will be spawned per connection)
|
/// Actual websocket statemachine (one will be spawned per connection)
|
||||||
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
|
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
|
||||||
// send a ping (unsupported by some browsers) just to kick things off and get a response
|
// send a ping (unsupported by some browsers) just to kick things off and get a response
|
||||||
if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() {
|
if socket
|
||||||
|
.send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
|
||||||
|
.await
|
||||||
|
.is_ok()
|
||||||
|
{
|
||||||
println!("Pinged {who}...");
|
println!("Pinged {who}...");
|
||||||
} else {
|
} else {
|
||||||
println!("Could not send ping {who}!");
|
println!("Could not send ping {who}!");
|
||||||
|
@ -131,7 +135,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
|
||||||
// connecting to server and receiving their greetings.
|
// connecting to server and receiving their greetings.
|
||||||
for i in 1..5 {
|
for i in 1..5 {
|
||||||
if socket
|
if socket
|
||||||
.send(Message::Text(format!("Hi {i} times!")))
|
.send(Message::Text(format!("Hi {i} times!").into()))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
|
@ -151,7 +155,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
|
||||||
for i in 0..n_msg {
|
for i in 0..n_msg {
|
||||||
// In case of any websocket error, we exit.
|
// In case of any websocket error, we exit.
|
||||||
if sender
|
if sender
|
||||||
.send(Message::Text(format!("Server message {i} ...")))
|
.send(Message::Text(format!("Server message {i} ...").into()))
|
||||||
.await
|
.await
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
|
@ -165,7 +169,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
|
||||||
if let Err(e) = sender
|
if let Err(e) = sender
|
||||||
.send(Message::Close(Some(CloseFrame {
|
.send(Message::Close(Some(CloseFrame {
|
||||||
code: axum::extract::ws::close_code::NORMAL,
|
code: axum::extract::ws::close_code::NORMAL,
|
||||||
reason: Cow::from("Goodbye"),
|
reason: Utf8Bytes::from_static("Goodbye"),
|
||||||
})))
|
})))
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in a new issue