mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-29 11:42:47 +01:00
Add support for websockets (#3)
Basically a copy/paste of whats in warp. Example usage: ```rust use tower_web::{prelude::*, ws::{ws, WebSocket}}; let app = route("/ws", ws(handle_socket)); async fn handle_socket(mut socket: WebSocket) { while let Some(msg) = socket.recv().await { let msg = msg.unwrap(); socket.send(msg).await.unwrap(); } } ```
This commit is contained in:
parent
002e3f92b3
commit
c9c507aece
9 changed files with 518 additions and 2 deletions
21
Cargo.toml
21
Cargo.toml
|
@ -12,6 +12,9 @@ readme = "README.md"
|
||||||
repository = "https://github.com/davidpdrsn/tower-web"
|
repository = "https://github.com/davidpdrsn/tower-web"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
ws = ["tokio-tungstenite", "sha-1", "base64"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
bytes = "1.0"
|
bytes = "1.0"
|
||||||
|
@ -28,16 +31,32 @@ tokio = { version = "1", features = ["time"] }
|
||||||
tower = { version = "0.4", features = ["util", "buffer"] }
|
tower = { version = "0.4", features = ["util", "buffer"] }
|
||||||
tower-http = { version = "0.1", features = ["add-extension"] }
|
tower-http = { version = "0.1", features = ["add-extension"] }
|
||||||
|
|
||||||
|
# optional dependencies
|
||||||
|
tokio-tungstenite = { optional = true, version = "0.14" }
|
||||||
|
sha-1 = { optional = true, version = "0.9.6" }
|
||||||
|
base64 = { optional = true, version = "0.13" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "0.14", features = ["full"] }
|
||||||
reqwest = { version = "0.11", features = ["json", "stream"] }
|
reqwest = { version = "0.11", features = ["json", "stream"] }
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||||
tower = { version = "0.4", features = ["util", "make", "timeout", "limit", "load-shed", "steer"] }
|
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = "0.2"
|
tracing-subscriber = "0.2"
|
||||||
uuid = "0.8"
|
uuid = "0.8"
|
||||||
|
|
||||||
|
[dev-dependencies.tower]
|
||||||
|
version = "0.4"
|
||||||
|
features = [
|
||||||
|
"util",
|
||||||
|
"make",
|
||||||
|
"timeout",
|
||||||
|
"limit",
|
||||||
|
"load-shed",
|
||||||
|
"steer",
|
||||||
|
"filter",
|
||||||
|
]
|
||||||
|
|
||||||
[dev-dependencies.tower-http]
|
[dev-dependencies.tower-http]
|
||||||
version = "0.1"
|
version = "0.1"
|
||||||
features = [
|
features = [
|
||||||
|
|
63
examples/websocket.rs
Normal file
63
examples/websocket.rs
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
//! Example websocket server.
|
||||||
|
//!
|
||||||
|
//! Run with
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! RUST_LOG=tower_http=debug,key_value_store=trace \
|
||||||
|
//! cargo run \
|
||||||
|
//! --features ws \
|
||||||
|
//! --example websocket
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use http::StatusCode;
|
||||||
|
use hyper::Server;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use tower::make::Shared;
|
||||||
|
use tower_http::{
|
||||||
|
services::ServeDir,
|
||||||
|
trace::{DefaultMakeSpan, TraceLayer},
|
||||||
|
};
|
||||||
|
use tower_web::{
|
||||||
|
prelude::*,
|
||||||
|
routing::nest,
|
||||||
|
service::ServiceExt,
|
||||||
|
ws::{ws, Message, WebSocket},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
// build our application with some routes
|
||||||
|
let app = nest(
|
||||||
|
"/",
|
||||||
|
ServeDir::new("examples/websocket")
|
||||||
|
.append_index_html_on_directories(true)
|
||||||
|
.handle_error(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string())),
|
||||||
|
)
|
||||||
|
// routes are matched from bottom to top, so we have to put `nest` at the
|
||||||
|
// top since it matches all routes
|
||||||
|
.route("/ws", ws(handle_socket))
|
||||||
|
// logging so we can see whats going on
|
||||||
|
.layer(
|
||||||
|
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||||
|
);
|
||||||
|
|
||||||
|
// run it with hyper
|
||||||
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||||
|
tracing::debug!("listening on {}", addr);
|
||||||
|
let server = Server::bind(&addr).serve(Shared::new(app));
|
||||||
|
server.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socket(mut socket: WebSocket) {
|
||||||
|
if let Some(msg) = socket.recv().await {
|
||||||
|
let msg = msg.unwrap();
|
||||||
|
println!("Client says: {:?}", msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
socket.send(Message::text("Hi!")).await.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
|
}
|
||||||
|
}
|
1
examples/websocket/index.html
Normal file
1
examples/websocket/index.html
Normal file
|
@ -0,0 +1 @@
|
||||||
|
<script src='script.js'></script>
|
9
examples/websocket/script.js
Normal file
9
examples/websocket/script.js
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
const socket = new WebSocket('ws://localhost:3000/ws');
|
||||||
|
|
||||||
|
socket.addEventListener('open', function (event) {
|
||||||
|
socket.send('Hello Server!');
|
||||||
|
});
|
||||||
|
|
||||||
|
socket.addEventListener('message', function (event) {
|
||||||
|
console.log('Message from server ', event.data);
|
||||||
|
});
|
|
@ -610,6 +610,10 @@ pub mod response;
|
||||||
pub mod routing;
|
pub mod routing;
|
||||||
pub mod service;
|
pub mod service;
|
||||||
|
|
||||||
|
#[cfg(feature = "ws")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||||
|
pub mod ws;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
|
|
|
@ -147,6 +147,17 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> IntoResponse for (HeaderMap, T)
|
||||||
|
where
|
||||||
|
T: Into<Body>,
|
||||||
|
{
|
||||||
|
fn into_response(self) -> Response<Body> {
|
||||||
|
let mut res = Response::new(self.1.into());
|
||||||
|
*res.headers_mut() = self.0;
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> IntoResponse for (StatusCode, HeaderMap, T)
|
impl<T> IntoResponse for (StatusCode, HeaderMap, T)
|
||||||
where
|
where
|
||||||
T: Into<Body>,
|
T: Into<Body>,
|
||||||
|
|
|
@ -770,12 +770,16 @@ where
|
||||||
|
|
||||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||||
path
|
path
|
||||||
} else {
|
} else {
|
||||||
path_and_query.path()
|
path_and_query.path()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if new_path.is_empty() {
|
||||||
|
new_path = "/";
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(query) = path_and_query.query() {
|
if let Some(query) = path_and_query.query() {
|
||||||
Some(
|
Some(
|
||||||
format!("{}?{}", new_path, query)
|
format!("{}?{}", new_path, query)
|
||||||
|
|
68
src/ws/future.rs
Normal file
68
src/ws/future.rs
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
//! Future types.
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
|
use http::{HeaderValue, Response, StatusCode};
|
||||||
|
use http_body::Full;
|
||||||
|
use sha1::{Digest, Sha1};
|
||||||
|
use std::{
|
||||||
|
convert::Infallible,
|
||||||
|
future::Future,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
|
||||||
|
|
||||||
|
impl ResponseFuture {
|
||||||
|
pub(super) fn ok(key: HeaderValue) -> Self {
|
||||||
|
Self(Ok(Some(key)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
|
||||||
|
Self(Err(Some((status, body))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Future for ResponseFuture {
|
||||||
|
type Output = Result<Response<Full<Bytes>>, Infallible>;
|
||||||
|
|
||||||
|
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let res = match self.get_mut().0.as_mut() {
|
||||||
|
Ok(key) => Response::builder()
|
||||||
|
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
|
.header(
|
||||||
|
http::header::CONNECTION,
|
||||||
|
HeaderValue::from_str("upgrade").unwrap(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::UPGRADE,
|
||||||
|
HeaderValue::from_str("websocket").unwrap(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::SEC_WEBSOCKET_ACCEPT,
|
||||||
|
sign(key.take().unwrap().as_bytes()),
|
||||||
|
)
|
||||||
|
.body(Full::new(Bytes::new()))
|
||||||
|
.unwrap(),
|
||||||
|
Err(err) => {
|
||||||
|
let (status, body) = err.take().unwrap();
|
||||||
|
Response::builder()
|
||||||
|
.status(status)
|
||||||
|
.body(Full::from(body))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Poll::Ready(Ok(res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sign(key: &[u8]) -> HeaderValue {
|
||||||
|
let mut sha1 = Sha1::default();
|
||||||
|
sha1.update(key);
|
||||||
|
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||||
|
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||||
|
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||||
|
}
|
337
src/ws/mod.rs
Normal file
337
src/ws/mod.rs
Normal file
|
@ -0,0 +1,337 @@
|
||||||
|
//! Handle websocket connections.
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use tower_web::{prelude::*, ws::{ws, WebSocket}};
|
||||||
|
//!
|
||||||
|
//! let app = route("/ws", ws(handle_socket));
|
||||||
|
//!
|
||||||
|
//! async fn handle_socket(mut socket: WebSocket) {
|
||||||
|
//! while let Some(msg) = socket.recv().await {
|
||||||
|
//! let msg = msg.unwrap();
|
||||||
|
//! socket.send(msg).await.unwrap();
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use crate::{routing::EmptyRouter, service::OnMethod};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use future::ResponseFuture;
|
||||||
|
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||||
|
use http::{
|
||||||
|
header::{self, HeaderName},
|
||||||
|
HeaderValue, Request, Response, StatusCode,
|
||||||
|
};
|
||||||
|
use http_body::Full;
|
||||||
|
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||||
|
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
|
||||||
|
use tokio_tungstenite::{
|
||||||
|
tungstenite::protocol::{self, WebSocketConfig},
|
||||||
|
WebSocketStream,
|
||||||
|
};
|
||||||
|
use tower::{BoxError, Service};
|
||||||
|
|
||||||
|
pub mod future;
|
||||||
|
|
||||||
|
/// Create a new [`WebSocketUpgrade`] service that will call the closure with
|
||||||
|
/// each connection.
|
||||||
|
///
|
||||||
|
/// See the [module docs](crate::ws) for more details.
|
||||||
|
pub fn ws<F, Fut>(callback: F) -> OnMethod<WebSocketUpgrade<F>, EmptyRouter>
|
||||||
|
where
|
||||||
|
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||||
|
Fut: Future<Output = ()> + Send + 'static,
|
||||||
|
{
|
||||||
|
let svc = WebSocketUpgrade {
|
||||||
|
callback,
|
||||||
|
config: WebSocketConfig::default(),
|
||||||
|
};
|
||||||
|
crate::service::get(svc)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [`Service`] that ugprades connections to websockets and spawns a task to
|
||||||
|
/// handle the stream.
|
||||||
|
///
|
||||||
|
/// Created with [`ws`].
|
||||||
|
///
|
||||||
|
/// See the [module docs](crate::ws) for more details.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WebSocketUpgrade<F> {
|
||||||
|
callback: F,
|
||||||
|
config: WebSocketConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("WebSocketUpgrade")
|
||||||
|
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
||||||
|
.field("config", &self.config)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> WebSocketUpgrade<F> {
|
||||||
|
/// Set the size of the internal message send queue.
|
||||||
|
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||||
|
self.config.max_send_queue = Some(max);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the maximum message size (defaults to 64 megabytes)
|
||||||
|
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||||
|
self.config.max_message_size = Some(max);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||||
|
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||||
|
self.config.max_frame_size = Some(max);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
|
||||||
|
where
|
||||||
|
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||||
|
Fut: Future<Output = ()> + Send + 'static,
|
||||||
|
{
|
||||||
|
type Response = Response<Full<Bytes>>;
|
||||||
|
type Error = Infallible;
|
||||||
|
type Future = ResponseFuture;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||||
|
if !header_eq(
|
||||||
|
&req,
|
||||||
|
header::CONNECTION,
|
||||||
|
HeaderValue::from_static("upgrade"),
|
||||||
|
) {
|
||||||
|
return ResponseFuture::err(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Connection header did not include 'upgrade'",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||||
|
return ResponseFuture::err(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"`Upgrade` header did not include 'websocket'",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !header_eq(
|
||||||
|
&req,
|
||||||
|
header::SEC_WEBSOCKET_VERSION,
|
||||||
|
HeaderValue::from_static("13"),
|
||||||
|
) {
|
||||||
|
return ResponseFuture::err(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"`Sec-Websocket-Version` header did not include '13'",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||||
|
key
|
||||||
|
} else {
|
||||||
|
return ResponseFuture::err(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"`Sec-Websocket-Key` header missing",
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||||
|
|
||||||
|
let config = self.config;
|
||||||
|
let callback = self.callback.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let upgraded = on_upgrade.await.unwrap();
|
||||||
|
let socket =
|
||||||
|
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||||
|
.await;
|
||||||
|
let socket = WebSocket { inner: socket };
|
||||||
|
callback(socket).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
ResponseFuture::ok(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
||||||
|
if let Some(header) = req.headers().get(&key) {
|
||||||
|
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A stream of websocket messages.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct WebSocket {
|
||||||
|
inner: WebSocketStream<Upgraded>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WebSocket {
|
||||||
|
/// Receive another message.
|
||||||
|
///
|
||||||
|
/// Returns `None` is stream has closed.
|
||||||
|
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
|
||||||
|
self.inner
|
||||||
|
.next()
|
||||||
|
.await
|
||||||
|
.map(|result| result.map_err(Into::into).map(|inner| Message { inner }))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a message.
|
||||||
|
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
|
||||||
|
self.inner.send(msg.inner).await.map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gracefully close this websocket.
|
||||||
|
pub async fn close(mut self) -> Result<(), BoxError> {
|
||||||
|
self.inner.close(None).await.map_err(Into::into)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A WebSocket message.
|
||||||
|
#[derive(Eq, PartialEq, Clone)]
|
||||||
|
pub struct Message {
|
||||||
|
inner: protocol::Message,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Message {
|
||||||
|
/// Construct a new Text `Message`.
|
||||||
|
pub fn text<S>(s: S) -> Message
|
||||||
|
where
|
||||||
|
S: Into<String>,
|
||||||
|
{
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::text(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a new Binary `Message`.
|
||||||
|
pub fn binary<V>(v: V) -> Message
|
||||||
|
where
|
||||||
|
V: Into<Vec<u8>>,
|
||||||
|
{
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::binary(v),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a new Ping `Message`.
|
||||||
|
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::Ping(v.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a new Pong `Message`.
|
||||||
|
///
|
||||||
|
/// Note that one rarely needs to manually construct a Pong message because
|
||||||
|
/// the underlying tungstenite socket automatically responds to the Ping
|
||||||
|
/// messages it receives. Manual construction might still be useful in some
|
||||||
|
/// cases like in tests or to send unidirectional heartbeats.
|
||||||
|
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::Pong(v.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct the default Close `Message`.
|
||||||
|
pub fn close() -> Message {
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::Close(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Construct a Close `Message` with a code and reason.
|
||||||
|
pub fn close_with<C, R>(code: C, reason: R) -> Message
|
||||||
|
where
|
||||||
|
C: Into<u16>,
|
||||||
|
R: Into<Cow<'static, str>>,
|
||||||
|
{
|
||||||
|
Message {
|
||||||
|
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
|
||||||
|
code: protocol::frame::coding::CloseCode::from(code.into()),
|
||||||
|
reason: reason.into(),
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this message is a Text message.
|
||||||
|
pub fn is_text(&self) -> bool {
|
||||||
|
self.inner.is_text()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this message is a Binary message.
|
||||||
|
pub fn is_binary(&self) -> bool {
|
||||||
|
self.inner.is_binary()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this message a is a Close message.
|
||||||
|
pub fn is_close(&self) -> bool {
|
||||||
|
self.inner.is_close()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this message is a Ping message.
|
||||||
|
pub fn is_ping(&self) -> bool {
|
||||||
|
self.inner.is_ping()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this message is a Pong message.
|
||||||
|
pub fn is_pong(&self) -> bool {
|
||||||
|
self.inner.is_pong()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to get the close frame (close code and reason)
|
||||||
|
pub fn close_frame(&self) -> Option<(u16, &str)> {
|
||||||
|
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
|
||||||
|
Some((close_frame.code.into(), close_frame.reason.as_ref()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to get a reference to the string text, if this is a Text message.
|
||||||
|
pub fn to_str(&self) -> Option<&str> {
|
||||||
|
if let protocol::Message::Text(s) = &self.inner {
|
||||||
|
Some(s)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the bytes of this message, if the message can contain data.
|
||||||
|
pub fn as_bytes(&self) -> &[u8] {
|
||||||
|
match self.inner {
|
||||||
|
protocol::Message::Text(ref s) => s.as_bytes(),
|
||||||
|
protocol::Message::Binary(ref v) => v,
|
||||||
|
protocol::Message::Ping(ref v) => v,
|
||||||
|
protocol::Message::Pong(ref v) => v,
|
||||||
|
protocol::Message::Close(_) => &[],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Destructure this message into binary data.
|
||||||
|
pub fn into_bytes(self) -> Vec<u8> {
|
||||||
|
self.inner.into_data()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Message {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
self.inner.fmt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Message> for Vec<u8> {
|
||||||
|
fn from(msg: Message) -> Self {
|
||||||
|
msg.into_bytes()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue