mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-22 07:08:16 +01:00
Add support for websockets (#3)
Basically a copy/paste of whats in warp. Example usage: ```rust use tower_web::{prelude::*, ws::{ws, WebSocket}}; let app = route("/ws", ws(handle_socket)); async fn handle_socket(mut socket: WebSocket) { while let Some(msg) = socket.recv().await { let msg = msg.unwrap(); socket.send(msg).await.unwrap(); } } ```
This commit is contained in:
parent
002e3f92b3
commit
c9c507aece
9 changed files with 518 additions and 2 deletions
21
Cargo.toml
21
Cargo.toml
|
@ -12,6 +12,9 @@ readme = "README.md"
|
|||
repository = "https://github.com/davidpdrsn/tower-web"
|
||||
version = "0.1.0"
|
||||
|
||||
[features]
|
||||
ws = ["tokio-tungstenite", "sha-1", "base64"]
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
bytes = "1.0"
|
||||
|
@ -28,16 +31,32 @@ tokio = { version = "1", features = ["time"] }
|
|||
tower = { version = "0.4", features = ["util", "buffer"] }
|
||||
tower-http = { version = "0.1", features = ["add-extension"] }
|
||||
|
||||
# optional dependencies
|
||||
tokio-tungstenite = { optional = true, version = "0.14" }
|
||||
sha-1 = { optional = true, version = "0.9.6" }
|
||||
base64 = { optional = true, version = "0.13" }
|
||||
|
||||
[dev-dependencies]
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
reqwest = { version = "0.11", features = ["json", "stream"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
|
||||
tower = { version = "0.4", features = ["util", "make", "timeout", "limit", "load-shed", "steer"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.2"
|
||||
uuid = "0.8"
|
||||
|
||||
[dev-dependencies.tower]
|
||||
version = "0.4"
|
||||
features = [
|
||||
"util",
|
||||
"make",
|
||||
"timeout",
|
||||
"limit",
|
||||
"load-shed",
|
||||
"steer",
|
||||
"filter",
|
||||
]
|
||||
|
||||
[dev-dependencies.tower-http]
|
||||
version = "0.1"
|
||||
features = [
|
||||
|
|
63
examples/websocket.rs
Normal file
63
examples/websocket.rs
Normal file
|
@ -0,0 +1,63 @@
|
|||
//! Example websocket server.
|
||||
//!
|
||||
//! Run with
|
||||
//!
|
||||
//! ```
|
||||
//! RUST_LOG=tower_http=debug,key_value_store=trace \
|
||||
//! cargo run \
|
||||
//! --features ws \
|
||||
//! --example websocket
|
||||
//! ```
|
||||
|
||||
use http::StatusCode;
|
||||
use hyper::Server;
|
||||
use std::net::SocketAddr;
|
||||
use tower::make::Shared;
|
||||
use tower_http::{
|
||||
services::ServeDir,
|
||||
trace::{DefaultMakeSpan, TraceLayer},
|
||||
};
|
||||
use tower_web::{
|
||||
prelude::*,
|
||||
routing::nest,
|
||||
service::ServiceExt,
|
||||
ws::{ws, Message, WebSocket},
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// build our application with some routes
|
||||
let app = nest(
|
||||
"/",
|
||||
ServeDir::new("examples/websocket")
|
||||
.append_index_html_on_directories(true)
|
||||
.handle_error(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string())),
|
||||
)
|
||||
// routes are matched from bottom to top, so we have to put `nest` at the
|
||||
// top since it matches all routes
|
||||
.route("/ws", ws(handle_socket))
|
||||
// logging so we can see whats going on
|
||||
.layer(
|
||||
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)),
|
||||
);
|
||||
|
||||
// run it with hyper
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
tracing::debug!("listening on {}", addr);
|
||||
let server = Server::bind(&addr).serve(Shared::new(app));
|
||||
server.await.unwrap();
|
||||
}
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket) {
|
||||
if let Some(msg) = socket.recv().await {
|
||||
let msg = msg.unwrap();
|
||||
println!("Client says: {:?}", msg);
|
||||
}
|
||||
|
||||
loop {
|
||||
socket.send(Message::text("Hi!")).await.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
}
|
||||
}
|
1
examples/websocket/index.html
Normal file
1
examples/websocket/index.html
Normal file
|
@ -0,0 +1 @@
|
|||
<script src='script.js'></script>
|
9
examples/websocket/script.js
Normal file
9
examples/websocket/script.js
Normal file
|
@ -0,0 +1,9 @@
|
|||
const socket = new WebSocket('ws://localhost:3000/ws');
|
||||
|
||||
socket.addEventListener('open', function (event) {
|
||||
socket.send('Hello Server!');
|
||||
});
|
||||
|
||||
socket.addEventListener('message', function (event) {
|
||||
console.log('Message from server ', event.data);
|
||||
});
|
|
@ -610,6 +610,10 @@ pub mod response;
|
|||
pub mod routing;
|
||||
pub mod service;
|
||||
|
||||
#[cfg(feature = "ws")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub mod ws;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
|
|
|
@ -147,6 +147,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> IntoResponse for (HeaderMap, T)
|
||||
where
|
||||
T: Into<Body>,
|
||||
{
|
||||
fn into_response(self) -> Response<Body> {
|
||||
let mut res = Response::new(self.1.into());
|
||||
*res.headers_mut() = self.0;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoResponse for (StatusCode, HeaderMap, T)
|
||||
where
|
||||
T: Into<Body>,
|
||||
|
|
|
@ -770,12 +770,16 @@ where
|
|||
|
||||
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
||||
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
||||
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||
let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
||||
path
|
||||
} else {
|
||||
path_and_query.path()
|
||||
};
|
||||
|
||||
if new_path.is_empty() {
|
||||
new_path = "/";
|
||||
}
|
||||
|
||||
if let Some(query) = path_and_query.query() {
|
||||
Some(
|
||||
format!("{}?{}", new_path, query)
|
||||
|
|
68
src/ws/future.rs
Normal file
68
src/ws/future.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
//! Future types.
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{HeaderValue, Response, StatusCode};
|
||||
use http_body::Full;
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||
#[derive(Debug)]
|
||||
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
|
||||
|
||||
impl ResponseFuture {
|
||||
pub(super) fn ok(key: HeaderValue) -> Self {
|
||||
Self(Ok(Some(key)))
|
||||
}
|
||||
|
||||
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
|
||||
Self(Err(Some((status, body))))
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for ResponseFuture {
|
||||
type Output = Result<Response<Full<Bytes>>, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let res = match self.get_mut().0.as_mut() {
|
||||
Ok(key) => Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(key.take().unwrap().as_bytes()),
|
||||
)
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap(),
|
||||
Err(err) => {
|
||||
let (status, body) = err.take().unwrap();
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.body(Full::from(body))
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||
}
|
337
src/ws/mod.rs
Normal file
337
src/ws/mod.rs
Normal file
|
@ -0,0 +1,337 @@
|
|||
//! Handle websocket connections.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use tower_web::{prelude::*, ws::{ws, WebSocket}};
|
||||
//!
|
||||
//! let app = route("/ws", ws(handle_socket));
|
||||
//!
|
||||
//! async fn handle_socket(mut socket: WebSocket) {
|
||||
//! while let Some(msg) = socket.recv().await {
|
||||
//! let msg = msg.unwrap();
|
||||
//! socket.send(msg).await.unwrap();
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use crate::{routing::EmptyRouter, service::OnMethod};
|
||||
use bytes::Bytes;
|
||||
use future::ResponseFuture;
|
||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
HeaderValue, Request, Response, StatusCode,
|
||||
};
|
||||
use http_body::Full;
|
||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::protocol::{self, WebSocketConfig},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tower::{BoxError, Service};
|
||||
|
||||
pub mod future;
|
||||
|
||||
/// Create a new [`WebSocketUpgrade`] service that will call the closure with
|
||||
/// each connection.
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
pub fn ws<F, Fut>(callback: F) -> OnMethod<WebSocketUpgrade<F>, EmptyRouter>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let svc = WebSocketUpgrade {
|
||||
callback,
|
||||
config: WebSocketConfig::default(),
|
||||
};
|
||||
crate::service::get(svc)
|
||||
}
|
||||
|
||||
/// [`Service`] that ugprades connections to websockets and spawns a task to
|
||||
/// handle the stream.
|
||||
///
|
||||
/// Created with [`ws`].
|
||||
///
|
||||
/// See the [module docs](crate::ws) for more details.
|
||||
#[derive(Clone)]
|
||||
pub struct WebSocketUpgrade<F> {
|
||||
callback: F,
|
||||
config: WebSocketConfig,
|
||||
}
|
||||
|
||||
impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> WebSocketUpgrade<F> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum message size (defaults to 64 megabytes)
|
||||
pub fn max_message_size(mut self, max: usize) -> Self {
|
||||
self.config.max_message_size = Some(max);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum frame size (defaults to 16 megabytes)
|
||||
pub fn max_frame_size(mut self, max: usize) -> Self {
|
||||
self.config.max_frame_size = Some(max);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
type Response = Response<Full<Bytes>>;
|
||||
type Error = Infallible;
|
||||
type Future = ResponseFuture;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_static("upgrade"),
|
||||
) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
HeaderValue::from_static("13"),
|
||||
) {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return ResponseFuture::err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
|
||||
let config = self.config;
|
||||
let callback = self.callback.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback(socket).await;
|
||||
});
|
||||
|
||||
ResponseFuture::ok(key)
|
||||
}
|
||||
}
|
||||
|
||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
||||
if let Some(header) = req.headers().get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// A stream of websocket messages.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<Upgraded>,
|
||||
}
|
||||
|
||||
impl WebSocket {
|
||||
/// Receive another message.
|
||||
///
|
||||
/// Returns `None` is stream has closed.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
|
||||
self.inner
|
||||
.next()
|
||||
.await
|
||||
.map(|result| result.map_err(Into::into).map(|inner| Message { inner }))
|
||||
}
|
||||
|
||||
/// Send a message.
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
|
||||
self.inner.send(msg.inner).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Gracefully close this websocket.
|
||||
pub async fn close(mut self) -> Result<(), BoxError> {
|
||||
self.inner.close(None).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// A WebSocket message.
|
||||
#[derive(Eq, PartialEq, Clone)]
|
||||
pub struct Message {
|
||||
inner: protocol::Message,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Construct a new Text `Message`.
|
||||
pub fn text<S>(s: S) -> Message
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::text(s),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Binary `Message`.
|
||||
pub fn binary<V>(v: V) -> Message
|
||||
where
|
||||
V: Into<Vec<u8>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::binary(v),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Ping `Message`.
|
||||
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Ping(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a new Pong `Message`.
|
||||
///
|
||||
/// Note that one rarely needs to manually construct a Pong message because
|
||||
/// the underlying tungstenite socket automatically responds to the Ping
|
||||
/// messages it receives. Manual construction might still be useful in some
|
||||
/// cases like in tests or to send unidirectional heartbeats.
|
||||
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Pong(v.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct the default Close `Message`.
|
||||
pub fn close() -> Message {
|
||||
Message {
|
||||
inner: protocol::Message::Close(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a Close `Message` with a code and reason.
|
||||
pub fn close_with<C, R>(code: C, reason: R) -> Message
|
||||
where
|
||||
C: Into<u16>,
|
||||
R: Into<Cow<'static, str>>,
|
||||
{
|
||||
Message {
|
||||
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
|
||||
code: protocol::frame::coding::CloseCode::from(code.into()),
|
||||
reason: reason.into(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Text message.
|
||||
pub fn is_text(&self) -> bool {
|
||||
self.inner.is_text()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Binary message.
|
||||
pub fn is_binary(&self) -> bool {
|
||||
self.inner.is_binary()
|
||||
}
|
||||
|
||||
/// Returns true if this message a is a Close message.
|
||||
pub fn is_close(&self) -> bool {
|
||||
self.inner.is_close()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Ping message.
|
||||
pub fn is_ping(&self) -> bool {
|
||||
self.inner.is_ping()
|
||||
}
|
||||
|
||||
/// Returns true if this message is a Pong message.
|
||||
pub fn is_pong(&self) -> bool {
|
||||
self.inner.is_pong()
|
||||
}
|
||||
|
||||
/// Try to get the close frame (close code and reason)
|
||||
pub fn close_frame(&self) -> Option<(u16, &str)> {
|
||||
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
|
||||
Some((close_frame.code.into(), close_frame.reason.as_ref()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get a reference to the string text, if this is a Text message.
|
||||
pub fn to_str(&self) -> Option<&str> {
|
||||
if let protocol::Message::Text(s) = &self.inner {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the bytes of this message, if the message can contain data.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
match self.inner {
|
||||
protocol::Message::Text(ref s) => s.as_bytes(),
|
||||
protocol::Message::Binary(ref v) => v,
|
||||
protocol::Message::Ping(ref v) => v,
|
||||
protocol::Message::Pong(ref v) => v,
|
||||
protocol::Message::Close(_) => &[],
|
||||
}
|
||||
}
|
||||
|
||||
/// Destructure this message into binary data.
|
||||
pub fn into_bytes(self) -> Vec<u8> {
|
||||
self.inner.into_data()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Message {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.inner.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Message> for Vec<u8> {
|
||||
fn from(msg: Message) -> Self {
|
||||
msg.into_bytes()
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue