diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 1773f131..ad9078eb 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -57,6 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 requests ([#734]) - **fixed:** Fix wrong `content-length` for `HEAD` requests to endpoints that returns chunked responses ([#755]) +- **changed:** Update to tokio-tungstenite 0.17 ([#791]) [#644]: https://github.com/tokio-rs/axum/pull/644 [#665]: https://github.com/tokio-rs/axum/pull/665 @@ -67,6 +68,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#734]: https://github.com/tokio-rs/axum/pull/734 [#755]: https://github.com/tokio-rs/axum/pull/755 [#783]: https://github.com/tokio-rs/axum/pull/783 +[#791]: https://github.com/tokio-rs/axum/pull/791 # 0.4.4 (13. January, 2022) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index e4cc471c..ef0e276e 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -53,7 +53,7 @@ multer = { optional = true, version = "2.0.0" } serde_json = { version = "1.0", optional = true, features = ["raw_value"] } serde_urlencoded = { version = "0.7", optional = true } sha-1 = { optional = true, version = "0.10" } -tokio-tungstenite = { optional = true, version = "0.16" } +tokio-tungstenite = { optional = true, version = "0.17" } [dev-dependencies] futures = "0.3" diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 44dc8ae9..8923ad92 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -337,13 +337,17 @@ impl Stream for WebSocket { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_next_unpin(cx).map(|option_msg| { - option_msg.map(|result_msg| { - result_msg - .map_err(Error::new) - .map(Message::from_tungstenite) - }) - }) + loop { + match futures_util::ready!(self.inner.poll_next_unpin(cx)) { + Some(Ok(msg)) => { + if let Some(msg) = Message::from_tungstenite(msg) { + return Poll::Ready(Some(Ok(msg))); + } + } + Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))), + None => return Poll::Ready(None), + } + } } } @@ -444,17 +448,20 @@ impl Message { } } - fn from_tungstenite(message: ts::Message) -> Self { + fn from_tungstenite(message: ts::Message) -> Option { match message { - ts::Message::Text(text) => Self::Text(text), - ts::Message::Binary(binary) => Self::Binary(binary), - ts::Message::Ping(ping) => Self::Ping(ping), - ts::Message::Pong(pong) => Self::Pong(pong), - ts::Message::Close(Some(close)) => Self::Close(Some(CloseFrame { + ts::Message::Text(text) => Some(Self::Text(text)), + ts::Message::Binary(binary) => Some(Self::Binary(binary)), + ts::Message::Ping(ping) => Some(Self::Ping(ping)), + ts::Message::Pong(pong) => Some(Self::Pong(pong)), + ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame { code: close.code.into(), reason: close.reason, - })), - ts::Message::Close(None) => Self::Close(None), + }))), + ts::Message::Close(None) => Some(Self::Close(None)), + // we can ignore `Frame` frames as recommended by the tungstenite maintainers + // https://github.com/snapview/tungstenite-rs/issues/268 + ts::Message::Frame(_) => None, } }