Add support for websockets (#3)

Basically a copy/paste of whats in warp.

Example usage:

```rust
use tower_web::{prelude::*, ws::{ws, WebSocket}};

let app = route("/ws", ws(handle_socket));

async fn handle_socket(mut socket: WebSocket) {
    while let Some(msg) = socket.recv().await {
        let msg = msg.unwrap();
        socket.send(msg).await.unwrap();
    }
}
```
This commit is contained in:
David Pedersen 2021-06-12 20:50:30 +02:00 committed by GitHub
parent 002e3f92b3
commit c9c507aece
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 518 additions and 2 deletions

View file

@ -12,6 +12,9 @@ readme = "README.md"
repository = "https://github.com/davidpdrsn/tower-web" repository = "https://github.com/davidpdrsn/tower-web"
version = "0.1.0" version = "0.1.0"
[features]
ws = ["tokio-tungstenite", "sha-1", "base64"]
[dependencies] [dependencies]
async-trait = "0.1" async-trait = "0.1"
bytes = "1.0" bytes = "1.0"
@ -28,16 +31,32 @@ tokio = { version = "1", features = ["time"] }
tower = { version = "0.4", features = ["util", "buffer"] } tower = { version = "0.4", features = ["util", "buffer"] }
tower-http = { version = "0.1", features = ["add-extension"] } tower-http = { version = "0.1", features = ["add-extension"] }
# optional dependencies
tokio-tungstenite = { optional = true, version = "0.14" }
sha-1 = { optional = true, version = "0.9.6" }
base64 = { optional = true, version = "0.13" }
[dev-dependencies] [dev-dependencies]
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
reqwest = { version = "0.11", features = ["json", "stream"] } reqwest = { version = "0.11", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] }
tower = { version = "0.4", features = ["util", "make", "timeout", "limit", "load-shed", "steer"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.2" tracing-subscriber = "0.2"
uuid = "0.8" uuid = "0.8"
[dev-dependencies.tower]
version = "0.4"
features = [
"util",
"make",
"timeout",
"limit",
"load-shed",
"steer",
"filter",
]
[dev-dependencies.tower-http] [dev-dependencies.tower-http]
version = "0.1" version = "0.1"
features = [ features = [

63
examples/websocket.rs Normal file
View file

@ -0,0 +1,63 @@
//! Example websocket server.
//!
//! Run with
//!
//! ```
//! RUST_LOG=tower_http=debug,key_value_store=trace \
//! cargo run \
//! --features ws \
//! --example websocket
//! ```
use http::StatusCode;
use hyper::Server;
use std::net::SocketAddr;
use tower::make::Shared;
use tower_http::{
services::ServeDir,
trace::{DefaultMakeSpan, TraceLayer},
};
use tower_web::{
prelude::*,
routing::nest,
service::ServiceExt,
ws::{ws, Message, WebSocket},
};
#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();
// build our application with some routes
let app = nest(
"/",
ServeDir::new("examples/websocket")
.append_index_html_on_directories(true)
.handle_error(|error| (StatusCode::INTERNAL_SERVER_ERROR, error.to_string())),
)
// routes are matched from bottom to top, so we have to put `nest` at the
// top since it matches all routes
.route("/ws", ws(handle_socket))
// logging so we can see whats going on
.layer(
TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default().include_headers(true)),
);
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
let server = Server::bind(&addr).serve(Shared::new(app));
server.await.unwrap();
}
async fn handle_socket(mut socket: WebSocket) {
if let Some(msg) = socket.recv().await {
let msg = msg.unwrap();
println!("Client says: {:?}", msg);
}
loop {
socket.send(Message::text("Hi!")).await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
}
}

View file

@ -0,0 +1 @@
<script src='script.js'></script>

View file

@ -0,0 +1,9 @@
const socket = new WebSocket('ws://localhost:3000/ws');
socket.addEventListener('open', function (event) {
socket.send('Hello Server!');
});
socket.addEventListener('message', function (event) {
console.log('Message from server ', event.data);
});

View file

@ -610,6 +610,10 @@ pub mod response;
pub mod routing; pub mod routing;
pub mod service; pub mod service;
#[cfg(feature = "ws")]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub mod ws;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View file

@ -147,6 +147,17 @@ where
} }
} }
impl<T> IntoResponse for (HeaderMap, T)
where
T: Into<Body>,
{
fn into_response(self) -> Response<Body> {
let mut res = Response::new(self.1.into());
*res.headers_mut() = self.0;
res
}
}
impl<T> IntoResponse for (StatusCode, HeaderMap, T) impl<T> IntoResponse for (StatusCode, HeaderMap, T)
where where
T: Into<Body>, T: Into<Body>,

View file

@ -770,12 +770,16 @@ where
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
path path
} else { } else {
path_and_query.path() path_and_query.path()
}; };
if new_path.is_empty() {
new_path = "/";
}
if let Some(query) = path_and_query.query() { if let Some(query) = path_and_query.query() {
Some( Some(
format!("{}?{}", new_path, query) format!("{}?{}", new_path, query)

68
src/ws/future.rs Normal file
View file

@ -0,0 +1,68 @@
//! Future types.
use bytes::Bytes;
use http::{HeaderValue, Response, StatusCode};
use http_body::Full;
use sha1::{Digest, Sha1};
use std::{
convert::Infallible,
future::Future,
pin::Pin,
task::{Context, Poll},
};
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
#[derive(Debug)]
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
impl ResponseFuture {
pub(super) fn ok(key: HeaderValue) -> Self {
Self(Ok(Some(key)))
}
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
Self(Err(Some((status, body))))
}
}
impl Future for ResponseFuture {
type Output = Result<Response<Full<Bytes>>, Infallible>;
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = match self.get_mut().0.as_mut() {
Ok(key) => Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(
http::header::SEC_WEBSOCKET_ACCEPT,
sign(key.take().unwrap().as_bytes()),
)
.body(Full::new(Bytes::new()))
.unwrap(),
Err(err) => {
let (status, body) = err.take().unwrap();
Response::builder()
.status(status)
.body(Full::from(body))
.unwrap()
}
};
Poll::Ready(Ok(res))
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}

337
src/ws/mod.rs Normal file
View file

@ -0,0 +1,337 @@
//! Handle websocket connections.
//!
//! # Example
//!
//! ```
//! use tower_web::{prelude::*, ws::{ws, WebSocket}};
//!
//! let app = route("/ws", ws(handle_socket));
//!
//! async fn handle_socket(mut socket: WebSocket) {
//! while let Some(msg) = socket.recv().await {
//! let msg = msg.unwrap();
//! socket.send(msg).await.unwrap();
//! }
//! }
//! ```
use crate::{routing::EmptyRouter, service::OnMethod};
use bytes::Bytes;
use future::ResponseFuture;
use futures_util::{sink::SinkExt, stream::StreamExt};
use http::{
header::{self, HeaderName},
HeaderValue, Request, Response, StatusCode,
};
use http_body::Full;
use hyper::upgrade::{OnUpgrade, Upgraded};
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
use tokio_tungstenite::{
tungstenite::protocol::{self, WebSocketConfig},
WebSocketStream,
};
use tower::{BoxError, Service};
pub mod future;
/// Create a new [`WebSocketUpgrade`] service that will call the closure with
/// each connection.
///
/// See the [module docs](crate::ws) for more details.
pub fn ws<F, Fut>(callback: F) -> OnMethod<WebSocketUpgrade<F>, EmptyRouter>
where
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let svc = WebSocketUpgrade {
callback,
config: WebSocketConfig::default(),
};
crate::service::get(svc)
}
/// [`Service`] that ugprades connections to websockets and spawns a task to
/// handle the stream.
///
/// Created with [`ws`].
///
/// See the [module docs](crate::ws) for more details.
#[derive(Clone)]
pub struct WebSocketUpgrade<F> {
callback: F,
config: WebSocketConfig,
}
impl<F> fmt::Debug for WebSocketUpgrade<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
.field("config", &self.config)
.finish()
}
}
impl<F> WebSocketUpgrade<F> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
self
}
/// Set the maximum message size (defaults to 64 megabytes)
pub fn max_message_size(mut self, max: usize) -> Self {
self.config.max_message_size = Some(max);
self
}
/// Set the maximum frame size (defaults to 16 megabytes)
pub fn max_frame_size(mut self, max: usize) -> Self {
self.config.max_frame_size = Some(max);
self
}
}
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
where
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
type Response = Response<Full<Bytes>>;
type Error = Infallible;
type Future = ResponseFuture;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
if !header_eq(
&req,
header::CONNECTION,
HeaderValue::from_static("upgrade"),
) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
if !header_eq(
&req,
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
) {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return ResponseFuture::err(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
let config = self.config;
let callback = self.callback.clone();
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
let socket = WebSocket { inner: socket };
callback(socket).await;
});
ResponseFuture::ok(key)
}
}
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
false
}
}
/// A stream of websocket messages.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<Upgraded>,
}
impl WebSocket {
/// Receive another message.
///
/// Returns `None` is stream has closed.
pub async fn recv(&mut self) -> Option<Result<Message, BoxError>> {
self.inner
.next()
.await
.map(|result| result.map_err(Into::into).map(|inner| Message { inner }))
}
/// Send a message.
pub async fn send(&mut self, msg: Message) -> Result<(), BoxError> {
self.inner.send(msg.inner).await.map_err(Into::into)
}
/// Gracefully close this websocket.
pub async fn close(mut self) -> Result<(), BoxError> {
self.inner.close(None).await.map_err(Into::into)
}
}
/// A WebSocket message.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
inner: protocol::Message,
}
impl Message {
/// Construct a new Text `Message`.
pub fn text<S>(s: S) -> Message
where
S: Into<String>,
{
Message {
inner: protocol::Message::text(s),
}
}
/// Construct a new Binary `Message`.
pub fn binary<V>(v: V) -> Message
where
V: Into<Vec<u8>>,
{
Message {
inner: protocol::Message::binary(v),
}
}
/// Construct a new Ping `Message`.
pub fn ping<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Ping(v.into()),
}
}
/// Construct a new Pong `Message`.
///
/// Note that one rarely needs to manually construct a Pong message because
/// the underlying tungstenite socket automatically responds to the Ping
/// messages it receives. Manual construction might still be useful in some
/// cases like in tests or to send unidirectional heartbeats.
pub fn pong<V: Into<Vec<u8>>>(v: V) -> Message {
Message {
inner: protocol::Message::Pong(v.into()),
}
}
/// Construct the default Close `Message`.
pub fn close() -> Message {
Message {
inner: protocol::Message::Close(None),
}
}
/// Construct a Close `Message` with a code and reason.
pub fn close_with<C, R>(code: C, reason: R) -> Message
where
C: Into<u16>,
R: Into<Cow<'static, str>>,
{
Message {
inner: protocol::Message::Close(Some(protocol::frame::CloseFrame {
code: protocol::frame::coding::CloseCode::from(code.into()),
reason: reason.into(),
})),
}
}
/// Returns true if this message is a Text message.
pub fn is_text(&self) -> bool {
self.inner.is_text()
}
/// Returns true if this message is a Binary message.
pub fn is_binary(&self) -> bool {
self.inner.is_binary()
}
/// Returns true if this message a is a Close message.
pub fn is_close(&self) -> bool {
self.inner.is_close()
}
/// Returns true if this message is a Ping message.
pub fn is_ping(&self) -> bool {
self.inner.is_ping()
}
/// Returns true if this message is a Pong message.
pub fn is_pong(&self) -> bool {
self.inner.is_pong()
}
/// Try to get the close frame (close code and reason)
pub fn close_frame(&self) -> Option<(u16, &str)> {
if let protocol::Message::Close(Some(close_frame)) = &self.inner {
Some((close_frame.code.into(), close_frame.reason.as_ref()))
} else {
None
}
}
/// Try to get a reference to the string text, if this is a Text message.
pub fn to_str(&self) -> Option<&str> {
if let protocol::Message::Text(s) = &self.inner {
Some(s)
} else {
None
}
}
/// Return the bytes of this message, if the message can contain data.
pub fn as_bytes(&self) -> &[u8] {
match self.inner {
protocol::Message::Text(ref s) => s.as_bytes(),
protocol::Message::Binary(ref v) => v,
protocol::Message::Ping(ref v) => v,
protocol::Message::Pong(ref v) => v,
protocol::Message::Close(_) => &[],
}
}
/// Destructure this message into binary data.
pub fn into_bytes(self) -> Vec<u8> {
self.inner.into_data()
}
}
impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.inner.fmt(f)
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_bytes()
}
}