All uses of `get_random()` had the form `&get_random(vec![0u8; SIZE])`, with `SIZE` being a constant. Building a `Vec` is unnecessary for two reasons. First, it requires a very short-lived dynamic memory allocation. Second, a `Vec` is a resizable object, which is useless in those contexts, where the random data has a fixed size and is only read. `get_random_bytes()` takes a constant as a generic parameter and returns an array with the requested number of random bytes. Stack safety analysis: the random bytes are allocated on the caller's stack for a very short time (until the encoding function has been called on the data). In some cases, the random bytes take less room than the `Vec` did (a `Vec` is 24 bytes on a 64-bit computer). The maximum size used is 180 bytes, which amounts to about 0.0086% of the default stack size of a Rust thread (2 MiB), so this is a non-issue. Also, most uses of those random bytes encode them using an `Encoding`. The function `crypto::encode_random_bytes()` generates random bytes and encodes them with the provided `Encoding`, which removes duplicated code. `generate_id()` has also been converted to use a constant generic parameter, since the length of the requested `String` is always a constant.
use std::{
atomic::{AtomicBool, Ordering},
use chrono::NaiveDateTime;
use futures::{SinkExt, StreamExt};
use rmpv::Value;
use rocket::{serde::json::Json, Route};
use serde_json::Value as JsonValue;
use tokio::{
net::{TcpListener, TcpStream},
use tokio_tungstenite::{
tungstenite::{handshake, Message},
use crate::{
db::models::{Cipher, Folder, Send, User},
Error, CONFIG,
/// Returns the Rocket routes served by this module: `negotiate` and `websockets_err`.
pub fn routes() -> Vec<Route> {
routes![negotiate, websockets_err]
/// Route handler that reports a proxy misconfiguration for the notifications hub.
///
/// Presumably reached when a WebSocket request for '/notifications/hub' lands on the
/// HTTP server instead of being proxied to the WebSocket server -- verify against the
/// route attribute.
///
/// When WebSockets are enabled, a process-wide one-shot flag ensures the message
/// below is produced only the first time this branch is taken.
fn websockets_err() -> EmptyResult {
// One-shot flag: starts as `true`; the first successful `compare_exchange` flips it to `false`.
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
// `compare_exchange(true, false, ..)` succeeds for exactly one caller, so later calls
// fall through to the `else` branch without repeating the message.
if CONFIG.websocket_enabled()
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
'/notifications/hub' should be proxied to the websocket server or notifications won't work.
Go to the Wiki for more info, or disable WebSockets setting WEBSOCKET_ENABLED=false.
} else {
/// Answers the hub negotiation request with a fresh connection ID and the list of
/// transports this server offers.
///
/// `_headers` is not read in the body; presumably it is a request guard that rejects
/// unauthenticated callers -- verify against the `Headers` type.
///
/// The JSON response carries `connectionId` and `availableTransports`; the latter
/// contains the "WebSockets" transport only when WebSockets are enabled in `CONFIG`.
fn negotiate(_headers: Headers) -> Json<JsonValue> {
use crate::crypto;
use data_encoding::BASE64URL;
// 16 random bytes, encoded as URL-safe base64, identify this connection.
let conn_id = crypto::encode_random_bytes::<16>(BASE64URL);
let mut available_transports: Vec<JsonValue> = Vec::new();
// Only advertise the WebSockets transport when the server actually provides it.
if CONFIG.websocket_enabled() {
available_transports.push(json!({"transport":"WebSockets", "transferFormats":["Text","Binary"]}));
// TODO: Implement transports
// Rocket WS support: https://github.com/SergioBenitez/Rocket/issues/90
// Rocket SSE support: https://github.com/SergioBenitez/Rocket/issues/33
// {"transport":"ServerSentEvents", "transferFormats":["Text"]},
// {"transport":"LongPolling", "transferFormats":["Text","Binary"]}
"connectionId": conn_id,
"availableTransports": available_transports
// Websockets server
/// Encodes `val` as MessagePack and prefixes it with the payload length.
///
/// The length prefix is a variable-length integer: 7 bits of the length per byte,
/// least-significant group first, with the high bit (0x80) set on every byte that is
/// followed by another length byte. The MessagePack payload is appended after the prefix.
///
/// # Panics
/// Panics with "Error encoding MsgPack" if `write_value` fails.
fn serialize(val: Value) -> Vec<u8> {
use rmpv::encode::write_value;
let mut buf = Vec::new();
write_value(&mut buf, &val).expect("Error encoding MsgPack");
// Add size bytes at the start
// Extracted from BinaryMessageFormat.js
let mut size: usize = buf.len();
let mut len_buf: Vec<u8> = Vec::new();
loop {
// Take the lowest 7 bits of the remaining length, then drop them from `size`.
let mut size_part = size & 0x7f;
size >>= 7;
// More length bits remain: set the continuation bit on this byte.
if size > 0 {
size_part |= 0x80;
// `size_part` is at most 0xff here (7 data bits plus the continuation bit), so the cast is lossless.
len_buf.push(size_part as u8);
if size == 0 {
// Moves the payload bytes after the length prefix, leaving `buf` empty.
len_buf.append(&mut buf);
/// Encodes `date` as a MessagePack extension value of type -1 (Timestamp).
///
/// The 8-byte big-endian payload packs the sub-second nanoseconds into the bits above
/// bit 34 and ORs the seconds value into the low bits.
///
/// NOTE(review): the packing only round-trips when `seconds` fits in 34 unsigned bits
/// (0 ..= 2^34 - 1). A negative value (sign-extended `i64`) or a larger one would
/// overlap the nanoseconds field -- confirm callers never pass such dates.
fn serialize_date(date: NaiveDateTime) -> Value {
let seconds: i64 = date.timestamp();
let nanos: i64 = date.timestamp_subsec_nanos().into();
// `<<` binds tighter than `|`, so this is `(nanos << 34) | seconds`.
let timestamp = nanos << 34 | seconds;
let bs = timestamp.to_be_bytes();
// -1 is Timestamp
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
Value::Ext(-1, bs.to_vec())
/// Converts an optional value into a MessagePack `Value`:
/// `Some(a)` becomes `a.into()`, `None` becomes `Value::Nil`.
fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
match option {
Some(a) => a.into(),
None => Value::Nil,
/// Byte 0x1e (ASCII record separator); stripped from the end of incoming text
/// messages in `handle_connection` and used to terminate `INITIAL_RESPONSE`.
const RECORD_SEPARATOR: u8 = 0x1e;
/// Bytes of an empty JSON object (`{}`) followed by the record separator.
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
/// First message a client is expected to send: the protocol name and version it
/// wants to use. Deserialized from the JSON text frame in `handle_connection` and
/// compared against `INITIAL_MESSAGE`.
#[derive(Deserialize, Copy, Clone, Eq, PartialEq)]
struct InitialMessage<'a> {
// Protocol name, borrowed from the incoming text.
protocol: &'a str,
// Protocol version number.
version: i32,
/// The only initial message that `handle_connection` recognizes as the protocol
/// announcement: protocol "messagepack", version 1.
static INITIAL_MESSAGE: InitialMessage<'static> = InitialMessage {
protocol: "messagepack",
version: 1,
// We attach the UUID to the sender so we can differentiate them when we need to remove them from the Vec
// The `Sender` is the sending half of the per-connection channel created in
// `handle_connection`; messages sent through it are forwarded to that connection's stream.
type UserSenders = (uuid::Uuid, Sender<Message>);
/// Registry of connected WebSocket clients.
///
/// Maps a user UUID to the senders of every connection currently registered for
/// that user. The map sits behind an `Arc`, so clones of this struct share the
/// same registry.
pub struct WebSocketUsers {
// Key: user UUID; value: one `(entry id, sender)` pair per registered connection.
map: Arc<dashmap::DashMap<String, Vec<UserSenders>>>,
impl WebSocketUsers {
/// Sends `data` as a binary WebSocket message to every connection registered for
/// `user_uuid`. Does nothing if the user has no entry in the map.
async fn send_update(&self, user_uuid: &str, data: &[u8]) {
// Clone the sender list out of the map so the map entry is not borrowed while
// awaiting the sends below.
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
for (_, sender) in user.iter() {
// A failed send means the receiving half of that connection's channel is gone.
if sender.send(Message::binary(data)).await.is_err() {
// TODO: Delete from map here too?
// NOTE: The last modified date needs to be updated before calling these methods
/// Notifies all connections of `user` with a payload containing the user's UUID
/// ("UserId") and its `updated_at` timestamp ("Date").
pub async fn send_user_update(&self, ut: UpdateType, user: &User) {
let data = create_update(
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
self.send_update(&user.uuid, &data).await;
/// Notifies the folder's owner with a payload containing the folder UUID ("Id"),
/// the owner's UUID ("UserId") and the folder's `updated_at` ("RevisionDate").
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
let data = create_update(
("Id".into(), folder.uuid.clone().into()),
("UserId".into(), folder.user_uuid.clone().into()),
("RevisionDate".into(), serialize_date(folder.updated_at)),
self.send_update(&folder.user_uuid, &data).await;
/// Notifies every user in `user_uuids` about `cipher`.
///
/// "UserId" and "OrganizationId" are `Nil` when the cipher has no user or
/// organization; "CollectionIds" is always sent as `Nil`.
pub async fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) {
let user_uuid = convert_option(cipher.user_uuid.clone());
let org_uuid = convert_option(cipher.organization_uuid.clone());
let data = create_update(
("Id".into(), cipher.uuid.clone().into()),
("UserId".into(), user_uuid),
("OrganizationId".into(), org_uuid),
("CollectionIds".into(), Value::Nil),
("RevisionDate".into(), serialize_date(cipher.updated_at)),
// The same encoded message is reused for every recipient.
for uuid in user_uuids {
self.send_update(uuid, &data).await;
/// Notifies every user in `user_uuids` about `send`, with its UUID ("Id"), its
/// optional owner ("UserId", `Nil` when absent) and its `revision_date` ("RevisionDate").
pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
let user_uuid = convert_option(send.user_uuid.clone());
let data = create_update(
("Id".into(), send.uuid.clone().into()),
("UserId".into(), user_uuid),
("RevisionDate".into(), serialize_date(send.revision_date)),
// The same encoded message is reused for every recipient.
for uuid in user_uuids {
self.send_update(uuid, &data).await;
/* Message Structure
1, // MessageType.Invocation
{}, // Headers (map)
null, // InvocationId
"ReceiveMessage", // Target
[ // Arguments
"ContextId": "app_id",
"Type": ut as i32,
"Payload": {}
/// Builds the encoded bytes of an update notification.
///
/// `payload` holds the key/value pairs placed under "Payload"; `ut` is sent as its
/// numeric discriminant under "Type". "ContextId" is always "app_id".
/// Presumably the resulting array is passed through `serialize` to obtain the
/// returned bytes -- verify against the end of the function.
fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType) -> Vec<u8> {
use rmpv::Value as V;
let value = V::Array(vec![
("ContextId".into(), "app_id".into()),
// The enum discriminant is the wire value of the update type.
("Type".into(), (ut as i32).into()),
("Payload".into(), payload.into()),
/// Builds the byte payload carried by the periodic `Message::Ping` that
/// `handle_connection` sends to each client.
fn create_ping() -> Vec<u8> {
/// Kind of change announced by a notification.
///
/// The explicit discriminants are the values written to the "Type" field of the
/// message by `create_update` (via `ut as i32`), so they must not be renumbered.
#[derive(Eq, PartialEq)]
pub enum UpdateType {
CipherUpdate = 0,
CipherCreate = 1,
LoginDelete = 2,
FolderDelete = 3,
Ciphers = 4,
Vault = 5,
OrgKeys = 6,
FolderCreate = 7,
FolderUpdate = 8,
CipherDelete = 9,
SyncSettings = 10,
LogOut = 11,
SyncSendCreate = 12,
SyncSendUpdate = 13,
SyncSendDelete = 14,
None = 100,
/// Shorthand for a borrowed Rocket `State` wrapping the `WebSocketUsers` registry.
pub type Notify<'a> = &'a rocket::State<WebSocketUsers>;
/// Creates the (initially empty) registry of WebSocket clients and, when WebSockets
/// are enabled in `CONFIG`, spawns the background task that accepts connections.
///
/// The spawned task binds a TCP listener on the configured address and port and
/// spawns one `handle_connection` task per accepted connection, each sharing the
/// same registry.
///
/// # Panics
/// The spawned task panics with "Can't listen on websocket port" if the bind fails.
pub fn start_notification_server() -> WebSocketUsers {
let users = WebSocketUsers {
map: Arc::new(dashmap::DashMap::new()),
if CONFIG.websocket_enabled() {
// Second handle to the same registry, moved into the accept task.
let users2 = users.clone();
tokio::spawn(async move {
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port");
// One-shot channel used as the shutdown signal for the accept loop below.
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
loop {
tokio::select! {
// New client: handle it on its own task so the accept loop keeps running.
Ok((stream, addr)) = listener.accept() => {
tokio::spawn(handle_connection(stream, users2.clone(), addr));
// Shutdown signal received (or the sending half was dropped).
_ = &mut shutdown_rx => {
info!("Shutting down WebSockets server!")
/// Serves one WebSocket client from handshake to disconnect.
///
/// During the handshake the auth token is extracted from the request and decoded;
/// the user UUID is taken from the token's `sub` claim. The connection is then
/// registered in `users` under that UUID, and a loop multiplexes three event sources:
/// frames from the client, updates queued for this client, and a 15-second keep-alive
/// tick. When the loop ends, this connection's entry is removed from the registry.
///
/// # Panics
/// Panics if the user UUID was not set by the handshake callback. Presumably the
/// handshake is rejected (and the function returns early) when the token is missing
/// or invalid, making this unreachable -- verify.
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: SocketAddr) -> Result<(), Error> {
// Filled in by the handshake callback below once the token has been validated.
let mut user_uuid: Option<String> = None;
info!("Accepting WS connection from {addr}");
// Accept connection, do initial handshake, validate auth token and get the user ID
use handshake::server::{Request, Response};
let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| {
if let Some(token) = get_request_token(req) {
if let Ok(claims) = crate::auth::decode_login(&token) {
user_uuid = Some(claims.sub);
return Ok(res);
let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
// Add a channel to send messages to this client to the map
// `entry_uuid` identifies this particular connection among the user's connections.
let entry_uuid = uuid::Uuid::new_v4();
// Bounded channel: at most 100 pending messages for this connection.
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
// Drives the periodic keep-alive Ping sent in the last `select!` branch.
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
tokio::select! {
// Frame received from the client.
res = stream.next() => {
match res {
Some(Ok(message)) => {
// Respond to any pings
if let Message::Ping(ping) = message {
if stream.send(Message::Pong(ping)).await.is_err() {
} else if let Message::Pong(_) = message {
/* Ignored */
// We should receive an initial message with the protocol and version, and we will reply to it
if let Message::Text(ref message) = message {
// The trailing record separator, if present, is not part of the JSON text.
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
// Just echo anything else the client sends
if stream.send(message).await.is_err() {
// The stream ended (`None`) or yielded an error: stop serving this client.
_ => break,
// Update queued for this client through its registered sender.
res = rx.recv() => {
match res {
Some(res) => {
if stream.send(res).await.is_err() {
// `recv` returns `None` once the channel is closed: nothing more can arrive.
None => break,
// Keep-alive tick.
_= interval.tick() => {
if stream.send(Message::Ping(create_ping())).await.is_err() {
info!("Closing WS connection from {addr}");
// Delete from map
// NOTE(review): this removes only the sender; after the last connection closes, the
// user's key appears to stay in the map with an empty Vec -- confirm whether that is intended.
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid);
/// Extracts the access token from a WebSocket handshake request.
///
/// Looks first at the `Authorization` header, returning whatever follows the
/// "Bearer " prefix. Otherwise falls back to the URI query string and returns the
/// value of `access_token=`.
///
/// NOTE(review): because of `take(1)`, only the FIRST `&`-separated query parameter
/// is inspected; an `access_token` that is not the first parameter is not found --
/// confirm this is intended.
fn get_request_token(req: &handshake::server::Request) -> Option<String> {
const ACCESS_TOKEN_KEY: &str = "access_token=";
// A header value that is not valid visible ASCII makes `to_str()` fail and is skipped.
if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) {
if let Some(token_part) = auth.strip_prefix("Bearer ") {
return Some(token_part.to_owned());
// Fall back to the query string.
if let Some(params) = req.uri().query() {
let params_iter = params.split('&').take(1);
for val in params_iter {
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
return Some(stripped.to_owned());