mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-08 08:56:24 +01:00
Change WebSocket API to use an extractor (#121)
Fixes https://github.com/tokio-rs/axum/issues/111 Example usage: ```rust use axum::{ prelude::*, extract::ws::{WebSocketUpgrade, WebSocket}, response::IntoResponse, }; let app = route("/ws", get(handler)); async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse { ws.on_upgrade(handle_socket) } async fn handle_socket(mut socket: WebSocket) { while let Some(msg) = socket.recv().await { let msg = if let Ok(msg) = msg { msg } else { // client disconnected return; }; if socket.send(msg).await.is_err() { // client disconnected return; } } } ```
This commit is contained in:
parent
404a3b5e8a
commit
4194cf70da
8 changed files with 619 additions and 652 deletions
|
@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## Breaking changes
|
||||
|
||||
- Change WebSocket API to use an extractor ([#121](https://github.com/tokio-rs/axum/pull/121))
|
||||
- Add `RoutingDsl::or` for combining routes. ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||
- Ensure a `HandleError` service created from `axum::ServiceExt::handle_error`
|
||||
_does not_ implement `RoutingDsl` as that could lead to confusing routing
|
||||
|
|
|
@ -14,9 +14,9 @@ use futures::{sink::SinkExt, stream::StreamExt};
|
|||
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
|
||||
use axum::prelude::*;
|
||||
use axum::response::Html;
|
||||
use axum::ws::{ws, Message, WebSocket};
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum::AddExtensionLayer;
|
||||
|
||||
// Our shared state
|
||||
|
@ -33,7 +33,7 @@ async fn main() {
|
|||
let app_state = Arc::new(AppState { user_set, tx });
|
||||
|
||||
let app = route("/", get(index))
|
||||
.route("/websocket", ws(websocket))
|
||||
.route("/websocket", get(websocket_handler))
|
||||
.layer(AddExtensionLayer::new(app_state));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
@ -44,10 +44,14 @@ async fn main() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn websocket(
|
||||
stream: WebSocket,
|
||||
async fn websocket_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
extract::Extension(state): extract::Extension<Arc<AppState>>,
|
||||
) {
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| websocket(socket, state))
|
||||
}
|
||||
|
||||
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
||||
// By splitting we can send and receive at the same time.
|
||||
let (mut sender, mut receiver) = stream.split();
|
||||
|
||||
|
|
|
@ -7,11 +7,14 @@
|
|||
//! ```
|
||||
|
||||
use axum::{
|
||||
extract::TypedHeader,
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
TypedHeader,
|
||||
},
|
||||
prelude::*,
|
||||
response::IntoResponse,
|
||||
routing::nest,
|
||||
service::ServiceExt,
|
||||
ws::{ws, Message, WebSocket},
|
||||
};
|
||||
use http::StatusCode;
|
||||
use std::net::SocketAddr;
|
||||
|
@ -44,7 +47,7 @@ async fn main() {
|
|||
)
|
||||
// routes are matched from bottom to top, so we have to put `nest` at the
|
||||
// top since it matches all routes
|
||||
.route("/ws", ws(handle_socket))
|
||||
.route("/ws", get(ws_handler))
|
||||
// logging so we can see whats going on
|
||||
.layer(
|
||||
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||
|
@ -59,15 +62,18 @@ async fn main() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
// websocket handlers can also use extractors
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
) {
|
||||
) -> impl IntoResponse {
|
||||
if let Some(TypedHeader(user_agent)) = user_agent {
|
||||
println!("`{}` connected", user_agent.as_str());
|
||||
}
|
||||
|
||||
ws.on_upgrade(handle_socket)
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket) {
|
||||
if let Some(msg) = socket.recv().await {
|
||||
if let Ok(msg) = msg {
|
||||
println!("Client says: {:?}", msg);
|
||||
|
|
|
@ -254,6 +254,10 @@ pub mod connect_info;
|
|||
pub mod extractor_middleware;
|
||||
pub mod rejection;
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub mod ws;
|
||||
|
||||
mod content_length_limit;
|
||||
mod extension;
|
||||
mod form;
|
||||
|
@ -292,6 +296,11 @@ pub mod multipart;
|
|||
#[doc(inline)]
|
||||
pub use self::multipart::Multipart;
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
#[doc(inline)]
|
||||
pub use self::ws::WebSocketUpgrade;
|
||||
|
||||
#[cfg(feature = "headers")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||
mod typed_header;
|
||||
|
|
586
src/extract/ws.rs
Normal file
586
src/extract/ws.rs
Normal file
|
@ -0,0 +1,586 @@
|
|||
//! Handle WebSocket connections.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{
|
||||
//! prelude::*,
|
||||
//! extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
//! response::IntoResponse,
|
||||
//! };
|
||||
//!
|
||||
//! let app = route("/ws", get(handler));
|
||||
//!
|
||||
//! async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
|
||||
//! ws.on_upgrade(handle_socket)
|
||||
//! }
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! while let Some(msg) = socket.recv().await {
|
||||
//! let msg = if let Ok(msg) = msg {
|
||||
//! msg
|
||||
//! } else {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! };
|
||||
//!
|
||||
//! if socket.send(msg).await.is_err() {
|
||||
//! // client disconnected
|
||||
//! return;
|
||||
//! }
|
||||
//! }
|
||||
//! }
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
|
||||
use self::rejection::*;
|
||||
use super::{rejection::*, FromRequest, RequestParts};
|
||||
use crate::response::IntoResponse;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{
|
||||
sink::{Sink, SinkExt},
|
||||
stream::{Stream, StreamExt},
|
||||
};
|
||||
use http::{
|
||||
header::{self, HeaderName, HeaderValue},
|
||||
Method, Response, StatusCode,
|
||||
};
|
||||
use hyper::{
|
||||
upgrade::{OnUpgrade, Upgraded},
|
||||
Body,
|
||||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::protocol::{self, WebSocketConfig},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tower::BoxError;
|
||||
|
||||
/// Extractor for establishing WebSocket connections.
|
||||
///
|
||||
/// Note: This extractor requires the request method to be `GET` so it should
|
||||
/// always be used with [`get`](crate::handler::get). Requests with other methods will be
|
||||
/// rejected.
|
||||
///
|
||||
/// See the [module docs](self) for an example.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocketUpgrade {
|
||||
config: WebSocketConfig,
|
||||
protocols: Option<Box<[Cow<'static, str>]>>,
|
||||
sec_websocket_key: HeaderValue,
|
||||
on_upgrade: OnUpgrade,
|
||||
sec_websocket_protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl WebSocketUpgrade {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum message size (defaults to 64 megabytes)
|
||||
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||
self.config.max_message_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the known protocols.
|
||||
///
|
||||
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
|
||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// prelude::*,
|
||||
/// extract::ws::{WebSocketUpgrade, WebSocket},
|
||||
/// response::IntoResponse,
|
||||
/// };
|
||||
///
|
||||
/// let app = route("/ws", get(handler));
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
|
||||
/// ws.protocols(["graphql-ws", "graphql-transport-ws"])
|
||||
/// .on_upgrade(|socket| async {
|
||||
/// // ...
|
||||
/// })
|
||||
/// }
|
||||
/// # async {
|
||||
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
pub fn protocols<I>(mut self, protocols: I) -> Self
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
self.protocols = Some(
|
||||
protocols
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>()
|
||||
.into(),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Finalize upgrading the connection and call the provided callback with
|
||||
/// the stream.
|
||||
///
|
||||
/// When using `WebSocketUpgrade`, the response produced by this method
|
||||
/// should be returned from the handler. See the [module docs](self) for an
|
||||
/// example.
|
||||
pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future + Send + 'static,
|
||||
{
|
||||
WebSocketUpgradeResponse {
|
||||
extractor: self,
|
||||
callback,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for WebSocketUpgrade
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = WebSocketUpgradeRejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
if req.method().ok_or(MethodAlreadyExtracted)? != Method::GET {
|
||||
return Err(MethodNotGet.into());
|
||||
}
|
||||
|
||||
if !header_contains(req, header::CONNECTION, "upgrade")? {
|
||||
return Err(InvalidConnectionHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(req, header::UPGRADE, "websocket")? {
|
||||
return Err(InvalidUpgradeHeader.into());
|
||||
}
|
||||
|
||||
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13")? {
|
||||
return Err(InvalidWebsocketVersionHeader.into());
|
||||
}
|
||||
|
||||
let sec_websocket_key = if let Some(key) = req
|
||||
.headers_mut()
|
||||
.ok_or(HeadersAlreadyExtracted)?
|
||||
.remove(header::SEC_WEBSOCKET_KEY)
|
||||
{
|
||||
key
|
||||
} else {
|
||||
return Err(WebsocketKeyHeaderMissing.into());
|
||||
};
|
||||
|
||||
let on_upgrade = req
|
||||
.extensions_mut()
|
||||
.ok_or(ExtensionsAlreadyExtracted)?
|
||||
.remove::<OnUpgrade>()
|
||||
.unwrap();
|
||||
|
||||
let sec_websocket_protocol = req
|
||||
.headers()
|
||||
.ok_or(HeadersAlreadyExtracted)?
|
||||
.get(header::SEC_WEBSOCKET_PROTOCOL)
|
||||
.cloned();
|
||||
|
||||
Ok(Self {
|
||||
config: Default::default(),
|
||||
protocols: None,
|
||||
sec_websocket_key,
|
||||
on_upgrade,
|
||||
sec_websocket_protocol,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn header_eq<B>(
|
||||
req: &RequestParts<B>,
|
||||
key: HeaderName,
|
||||
value: &'static str,
|
||||
) -> Result<bool, HeadersAlreadyExtracted> {
|
||||
if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) {
|
||||
Ok(header.as_bytes().eq_ignore_ascii_case(value.as_bytes()))
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn header_contains<B>(
|
||||
req: &RequestParts<B>,
|
||||
key: HeaderName,
|
||||
value: &'static str,
|
||||
) -> Result<bool, HeadersAlreadyExtracted> {
|
||||
let header = if let Some(header) = req.headers().ok_or(HeadersAlreadyExtracted)?.get(&key) {
|
||||
header
|
||||
} else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
|
||||
Ok(header.to_ascii_lowercase().contains(value))
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSocketUpgradeResponse<F> {
|
||||
extractor: WebSocketUpgrade,
|
||||
callback: F,
|
||||
}
|
||||
|
||||
impl<F, Fut> IntoResponse for WebSocketUpgradeResponse<F>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future + Send + 'static,
|
||||
{
|
||||
fn into_response(self) -> Response<Body> {
|
||||
// check requested protocols
|
||||
let protocol = self
|
||||
.extractor
|
||||
.sec_websocket_protocol
|
||||
.as_ref()
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
let protocols = self.extractor.protocols.as_ref()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Protocol` header is invalid",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let callback = self.callback;
|
||||
let on_upgrade = self.extractor.on_upgrade;
|
||||
let config = self.extractor.config;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(header::UPGRADE, HeaderValue::from_str("websocket").unwrap())
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(self.extractor.sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(Body::empty()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of WebSocket messages.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<Upgraded>,
|
||||
}
|
||||
|
||||
impl WebSocket {
|
||||
/// Receive another message.
|
||||
///
|
||||
/// Returns `None` if the stream stream has closed.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
|
||||
self.next().await
|
||||
}
|
||||
|
||||
/// Send a message.
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
|
||||
self.inner.send(msg.inner).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Gracefully close this WebSocket.
|
||||
pub async fn close(mut self) -> Result<(), BoxError> {
|
||||
self.inner.close(None).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocket {
|
||||
type Item = Result<Message, BoxError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.inner.poll_next_unpin(cx).map(|option_msg| {
|
||||
option_msg.map(|result_msg| {
|
||||
result_msg
|
||||
.map_err(Into::into)
|
||||
.map(|inner| Message { inner })
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WebSocket {
|
||||
type Error = BoxError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.inner)
|
||||
.start_send(item.inner)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// A WebSocket message.
|
||||
#[derive(Eq, PartialEq, Clone)]
|
||||
pub struct Message {
|
||||
inner: protocol::Message,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Construct a new Text `Message`.
|
||||
pub fn text<S>(s: S) -> Message
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::text(s),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Binary `Message`.
|
||||
pub fn binary<V>(v: V) -> Message
|
||||
where
|
||||
V: Into<Vec<u8>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::binary(v),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Ping `Message`.
|
||||
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Ping(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Pong `Message`.
|
||||
///
|
||||
/// Note that one rarely needs to manually construct a Pong message because
|
||||
/// the underlying tungstenite socket automatically responds to the Ping
|
||||
/// messages it receives. Manual construction might still be useful in some
|
||||
/// cases like in tests or to send unidirectional heartbeats.
|
||||
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Pong(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct the default Close `Message`.
|
||||
pub fn close() -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Close(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a Close `Message` with a code and reason.
|
||||
pub fn close_with<C, R>(code: C, reason: R) -> Message
|
||||
where
|
||||
C: Into<u16>,
|
||||
R: Into<Cow<'static, str>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
|
||||
code: protocol::frame::coding::CloseCode::from(code.into()),
|
||||
reason: reason.into(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Text message.
|
||||
pub fn is_text(&self) -> bool {
|
||||
self.inner.is_text()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Binary message.
|
||||
pub fn is_binary(&self) -> bool {
|
||||
self.inner.is_binary()
|
||||
}
|
||||
|
||||
/// Returns true if this message a is a Close message.
|
||||
pub fn is_close(&self) -> bool {
|
||||
self.inner.is_close()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Ping message.
|
||||
pub fn is_ping(&self) -> bool {
|
||||
self.inner.is_ping()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Pong message.
|
||||
pub fn is_pong(&self) -> bool {
|
||||
self.inner.is_pong()
|
||||
}
|
||||
|
||||
/// Try to get the close frame (close code and reason)
|
||||
pub fn close_frame(&self) -> Option<(u16, &str)> {
|
||||
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
|
||||
Some((close_frame.code.into(), close_frame.reason.as_ref()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get a reference to the string text, if this is a Text message.
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
if let protocol::Message::Text(s) = &self.inner {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the bytes of this message, if the message can contain data.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
match self.inner {
|
||||
protocol::Message::Text(ref s) => s.as_bytes(),
|
||||
protocol::Message::Binary(ref v) => v,
|
||||
protocol::Message::Ping(ref v) => v,
|
||||
protocol::Message::Pong(ref v) => v,
|
||||
protocol::Message::Close(_) => &[],
|
||||
}
|
||||
}
|
||||
|
||||
/// Destructure this message into binary data.
|
||||
pub fn into_bytes(self) -> Vec<u8> {
|
||||
self.inner.into_data()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Message {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.inner.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Message> for Vec<u8> {
|
||||
fn from(msg: Message) -> Self {
|
||||
msg.into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
||||
|
||||
pub mod rejection {
|
||||
//! WebSocket specific rejections.
|
||||
|
||||
use crate::extract::rejection::*;
|
||||
|
||||
define_rejection! {
|
||||
#[status = METHOD_NOT_ALLOWED]
|
||||
#[body = "Request method must be `GET`"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct MethodNotGet;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "Connection header did not include 'upgrade'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidConnectionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Upgrade` header did not include 'websocket'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidUpgradeHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-Websocket-Version` header did not include '13'"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct InvalidWebsocketVersionHeader;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "`Sec-Websocket-Key` header missing"]
|
||||
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub struct WebsocketKeyHeaderMissing;
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
///
|
||||
/// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
|
||||
/// extractor can fail.
|
||||
pub enum WebSocketUpgradeRejection {
|
||||
MethodNotGet,
|
||||
InvalidConnectionHeader,
|
||||
InvalidUpgradeHeader,
|
||||
InvalidWebsocketVersionHeader,
|
||||
WebsocketKeyHeaderMissing,
|
||||
MethodAlreadyExtracted,
|
||||
HeadersAlreadyExtracted,
|
||||
ExtensionsAlreadyExtracted,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -727,10 +727,6 @@ pub mod routing;
|
|||
pub mod service;
|
||||
pub mod sse;
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub mod ws;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
//! Future types.
|
||||
|
||||
use crate::body::BoxBody;
|
||||
use http::Response;
|
||||
use std::convert::Infallible;
|
||||
|
||||
opaque_future! {
|
||||
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
pub type ResponseFuture = futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
|
||||
}
|
625
src/ws/mod.rs
625
src/ws/mod.rs
|
@ -1,625 +0,0 @@
|
|||
//! Handle websocket connections.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{prelude::*, ws::{ws, WebSocket}};
|
||||
//!
|
||||
//! let app = route("/ws", ws(handle_socket));
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! while let Some(msg) = socket.recv().await {
|
||||
//! let msg = msg.unwrap();
|
||||
//! socket.send(msg).await.unwrap();
|
||||
//! }
|
||||
//! }
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
//!
|
||||
//! Websocket handlers can also use extractors, however the first function
|
||||
//! argument must be of type [`WebSocket`]:
|
||||
//!
|
||||
//! ```
|
||||
//! use axum::{prelude::*, extract::{RequestParts, FromRequest}, ws::{ws, WebSocket}};
|
||||
//! use http::{HeaderMap, StatusCode};
|
||||
//!
|
||||
//! /// An extractor that authorizes requests.
|
||||
//! struct RequireAuth;
|
||||
//!
|
||||
//! #[async_trait::async_trait]
|
||||
//! impl<B> FromRequest<B> for RequireAuth
|
||||
//! where
|
||||
//! B: Send,
|
||||
//! {
|
||||
//! type Rejection = StatusCode;
|
||||
//!
|
||||
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
//! # unimplemented!()
|
||||
//! // Put your auth logic here...
|
||||
//! }
|
||||
//! }
|
||||
//!
|
||||
//! let app = route("/ws", ws(handle_socket));
|
||||
//!
|
||||
//! async fn handle_socket(
|
||||
//! mut socket: WebSocket,
|
||||
//! // Run `RequireAuth` for each request before upgrading.
|
||||
//! _auth: RequireAuth,
|
||||
//! ) {
|
||||
//! // ...
|
||||
//! }
|
||||
//! # async {
|
||||
//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||
//! # };
|
||||
//! ```
|
||||
|
||||
use crate::body::{box_body, BoxBody};
|
||||
use crate::extract::{FromRequest, RequestParts};
|
||||
use crate::response::IntoResponse;
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use future::ResponseFuture;
|
||||
use futures_util::{
|
||||
sink::{Sink, SinkExt},
|
||||
stream::{Stream, StreamExt},
|
||||
};
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
HeaderValue, Request, Response, StatusCode,
|
||||
};
|
||||
use http_body::Full;
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
|
||||
task::Poll,
|
||||
};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::protocol::{self, WebSocketConfig},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tower::{BoxError, Service};
|
||||
|
||||
pub mod future;
|
||||
|
||||
/// Create a new [`WebSocketUpgrade`] service that will call the closure with
|
||||
/// each connection.
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
pub fn ws<F, B, T>(callback: F) -> WebSocketUpgrade<F, B, T>
|
||||
where
|
||||
F: WebSocketHandler<B, T>,
|
||||
{
|
||||
WebSocketUpgrade {
|
||||
callback,
|
||||
config: WebSocketConfig::default(),
|
||||
protocols: Vec::new().into(),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for async functions that can be used to handle websocket requests.
|
||||
///
|
||||
/// You shouldn't need to depend on this trait directly. It is automatically
|
||||
/// implemented to closures of the right types.
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
#[async_trait]
|
||||
pub trait WebSocketHandler<B, In>: Sized {
|
||||
// This seals the trait. We cannot use the regular "sealed super trait"
|
||||
// approach due to coherence.
|
||||
#[doc(hidden)]
|
||||
type Sealed: crate::handler::sealed::HiddentTrait;
|
||||
|
||||
/// Call the handler with the given websocket stream and input parsed by
|
||||
/// extractors.
|
||||
async fn call(self, stream: WebSocket, input: In);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<F, Fut, B> WebSocketHandler<B, ()> for F
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send,
|
||||
Fut: Future<Output = ()> + Send,
|
||||
B: Send,
|
||||
{
|
||||
type Sealed = crate::handler::sealed::Hidden;
|
||||
|
||||
async fn call(self, stream: WebSocket, _: ()) {
|
||||
self(stream).await
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_ws_handler {
|
||||
() => {
|
||||
};
|
||||
|
||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||
#[async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<F, Fut, B, $head, $($tail,)*> WebSocketHandler<B, ($head, $($tail,)*)> for F
|
||||
where
|
||||
B: Send,
|
||||
$head: FromRequest<B> + Send + 'static,
|
||||
$( $tail: FromRequest<B> + Send + 'static, )*
|
||||
F: FnOnce(WebSocket, $head, $($tail,)*) -> Fut + Send,
|
||||
Fut: Future<Output = ()> + Send,
|
||||
{
|
||||
type Sealed = crate::handler::sealed::Hidden;
|
||||
|
||||
async fn call(self, stream: WebSocket, ($head, $($tail,)*): ($head, $($tail,)*)) {
|
||||
self(stream, $head, $($tail,)*).await
|
||||
}
|
||||
}
|
||||
|
||||
impl_ws_handler!($($tail,)*);
|
||||
};
|
||||
}
|
||||
|
||||
impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||
|
||||
/// [`Service`] that upgrades connections to websockets and spawns a task to
|
||||
/// handle the stream.
|
||||
///
|
||||
/// Created with [`ws`].
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
pub struct WebSocketUpgrade<F, B, T> {
|
||||
callback: F,
|
||||
config: WebSocketConfig,
|
||||
protocols: Arc<[Cow<'static, str>]>,
|
||||
_request_body: PhantomData<fn() -> (B, T)>,
|
||||
}
|
||||
|
||||
impl<F, B, T> Clone for WebSocketUpgrade<F, B, T>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
callback: self.callback.clone(),
|
||||
config: self.config,
|
||||
protocols: self.protocols.clone(),
|
||||
_request_body: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, B, T> fmt::Debug for WebSocketUpgrade<F, B, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, B, T> WebSocketUpgrade<F, B, T> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum message size (defaults to 64 megabytes)
|
||||
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||
self.config.max_message_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the known protocols.
|
||||
///
|
||||
/// If the protocol name specified by `Sec-WebSocket-Protocol` header
|
||||
/// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and return the protocol name.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use axum::prelude::*;
|
||||
/// # use axum::ws::{ws, WebSocket};
|
||||
/// # use std::net::SocketAddr;
|
||||
/// #
|
||||
/// # async fn handle_socket(socket: WebSocket) {
|
||||
/// # todo!()
|
||||
/// # }
|
||||
/// #
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let app = route("/ws", ws(handle_socket).protocols(["graphql-ws", "graphql-transport-ws"]));
|
||||
/// #
|
||||
/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
/// # axum::Server::bind(&addr)
|
||||
/// # .serve(app.into_make_service())
|
||||
/// # .await
|
||||
/// # .unwrap();
|
||||
/// # }
|
||||
///
|
||||
/// ```
|
||||
pub fn protocols<I>(mut self, protocols: I) -> Self
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: Into<Cow<'static, str>>,
|
||||
{
|
||||
self.protocols = protocols
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
|
||||
where
|
||||
F: WebSocketHandler<ReqBody, T> + Clone + Send + 'static,
|
||||
T: FromRequest<ReqBody> + Send + 'static,
|
||||
ReqBody: Send + 'static,
|
||||
{
|
||||
type Response = Response<BoxBody>;
|
||||
type Error = Infallible;
|
||||
type Future = ResponseFuture;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
let this = self.clone();
|
||||
let protocols = self.protocols.clone();
|
||||
|
||||
ResponseFuture {
|
||||
future: Box::pin(async move {
|
||||
if req.method() != http::Method::GET {
|
||||
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||
}
|
||||
|
||||
if !header_contains(&req, header::CONNECTION, "upgrade") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::UPGRADE, "websocket") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
|
||||
// check requested protocols
|
||||
let protocol =
|
||||
req.headers()
|
||||
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Protocol` header is invalid",
|
||||
);
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
|
||||
let config = this.config;
|
||||
let callback = this.callback.clone();
|
||||
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(rejection) => {
|
||||
let res = rejection.into_response().map(box_body);
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
upgraded,
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback.call(socket, input).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
|
||||
|
||||
Ok(res)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
|
||||
let res = Response::builder()
|
||||
.status(status)
|
||||
.body(box_body(Full::from(body)))
|
||||
.unwrap();
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
|
||||
if let Some(header) = req.headers().get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn header_contains<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
|
||||
let header = if let Some(header) = req.headers().get(&key) {
|
||||
header
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
|
||||
header.to_ascii_lowercase().contains(value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
||||
|
||||
/// A stream of websocket messages.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<Upgraded>,
|
||||
}
|
||||
|
||||
impl WebSocket {
|
||||
/// Receive another message.
|
||||
///
|
||||
/// Returns `None` if the stream stream has closed.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
|
||||
self.next().await
|
||||
}
|
||||
|
||||
/// Send a message.
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
|
||||
self.inner.send(msg.inner).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Gracefully close this websocket.
|
||||
pub async fn close(mut self) -> Result<(), BoxError> {
|
||||
self.inner.close(None).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocket {
|
||||
type Item = Result<Message, BoxError>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.inner.poll_next_unpin(cx).map(|option_msg| {
|
||||
option_msg.map(|result_msg| {
|
||||
result_msg
|
||||
.map_err(Into::into)
|
||||
.map(|inner| Message { inner })
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WebSocket {
|
||||
type Error = BoxError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
Pin::new(&mut self.inner)
|
||||
.start_send(item.inner)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||
Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// A WebSocket message.
|
||||
#[derive(Eq, PartialEq, Clone)]
|
||||
pub struct Message {
|
||||
inner: protocol::Message,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Construct a new Text `Message`.
|
||||
pub fn text<S>(s: S) -> Message
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::text(s),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Binary `Message`.
|
||||
pub fn binary<V>(v: V) -> Message
|
||||
where
|
||||
V: Into<Vec<u8>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::binary(v),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Ping `Message`.
|
||||
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Ping(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Pong `Message`.
|
||||
///
|
||||
/// Note that one rarely needs to manually construct a Pong message because
|
||||
/// the underlying tungstenite socket automatically responds to the Ping
|
||||
/// messages it receives. Manual construction might still be useful in some
|
||||
/// cases like in tests or to send unidirectional heartbeats.
|
||||
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Pong(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct the default Close `Message`.
|
||||
pub fn close() -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Close(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a Close `Message` with a code and reason.
|
||||
pub fn close_with<C, R>(code: C, reason: R) -> Message
|
||||
where
|
||||
C: Into<u16>,
|
||||
R: Into<Cow<'static, str>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
|
||||
code: protocol::frame::coding::CloseCode::from(code.into()),
|
||||
reason: reason.into(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Text message.
|
||||
pub fn is_text(&self) -> bool {
|
||||
self.inner.is_text()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Binary message.
|
||||
pub fn is_binary(&self) -> bool {
|
||||
self.inner.is_binary()
|
||||
}
|
||||
|
||||
/// Returns true if this message a is a Close message.
|
||||
pub fn is_close(&self) -> bool {
|
||||
self.inner.is_close()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Ping message.
|
||||
pub fn is_ping(&self) -> bool {
|
||||
self.inner.is_ping()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Pong message.
|
||||
pub fn is_pong(&self) -> bool {
|
||||
self.inner.is_pong()
|
||||
}
|
||||
|
||||
/// Try to get the close frame (close code and reason)
|
||||
pub fn close_frame(&self) -> Option<(u16, &str)> {
|
||||
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
|
||||
Some((close_frame.code.into(), close_frame.reason.as_ref()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get a reference to the string text, if this is a Text message.
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
if let protocol::Message::Text(s) = &self.inner {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the bytes of this message, if the message can contain data.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
match self.inner {
|
||||
protocol::Message::Text(ref s) => s.as_bytes(),
|
||||
protocol::Message::Binary(ref v) => v,
|
||||
protocol::Message::Ping(ref v) => v,
|
||||
protocol::Message::Pong(ref v) => v,
|
||||
protocol::Message::Close(_) => &[],
|
||||
}
|
||||
}
|
||||
|
||||
/// Destructure this message into binary data.
|
||||
pub fn into_bytes(self) -> Vec<u8> {
|
||||
self.inner.into_data()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Message {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.inner.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Message> for Vec<u8> {
|
||||
fn from(msg: Message) -> Self {
|
||||
msg.into_bytes()
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue