axum/examples/chat.rs
Grzegorz Baranski 4792d0c15c
Make ws::Message an enum for easier frame type matching (#116)
* feat(ws): make Message an enum to allow pattern matching

* fix(examples): update to new websockets `Message`

* fix(ws): remove wildcard imports

* fix(examples/chat): apply clippy's never_loop

* style: `cargo fmt`

* docs: add license notes above parts that are copied

* fix(ws): make CloseCode an alias to u16

* fix: move Message from src/ws/mod.rs to src/extract/ws.rs

* docs: add changelog entry about websocket messages

* fix: remove useless conversions to the same type
2021-08-07 19:47:22 +02:00

137 lines
4 KiB
Rust

//! Example chat application.
//!
//! Run with
//!
//! ```not_rust
//! cargo run --features=ws --example chat
//! ```
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::broadcast;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::prelude::*;
use axum::response::{Html, IntoResponse};
use axum::AddExtensionLayer;
/// State shared by every chat connection (wrapped in an `Arc` by `main`).
struct AppState {
    /// Usernames currently claimed by connected clients.
    user_set: Mutex<HashSet<String>>,
    /// Sending half of the broadcast channel that every client subscribes to.
    tx: broadcast::Sender<String>,
}
/// Starts the chat server on 127.0.0.1:3000.
///
/// `GET /` serves the chat page and `GET /websocket` upgrades to a chat
/// connection; both routes see one shared `AppState` via `AddExtensionLayer`.
#[tokio::main]
async fn main() {
    // Broadcast channel with capacity 100. `_rx` is a named binding (not `_`),
    // so this receiver stays alive until `main` returns.
    let (tx, _rx) = broadcast::channel(100);
    let shared = Arc::new(AppState {
        user_set: Mutex::new(HashSet::new()),
        tx,
    });

    let routes = route("/", get(index)).route("/websocket", get(websocket_handler));
    let app = routes.layer(AddExtensionLayer::new(shared));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let server = axum::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();
}
/// `GET /websocket`: completes the WebSocket upgrade and runs the chat
/// session for the new socket together with the shared application state.
async fn websocket_handler(
    ws: WebSocketUpgrade,
    extract::Extension(app_state): extract::Extension<Arc<AppState>>,
) -> impl IntoResponse {
    ws.on_upgrade(move |socket| websocket(socket, app_state))
}
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
// By splitting we can send and receive at the same time.
let (mut sender, mut receiver) = stream.split();
// Username gets set in the receive loop, if its valid
let mut username = String::new();
// Loop until a text message is found.
while let Some(Ok(message)) = receiver.next().await {
if let Message::Text(name) = message {
// If username that is sent by client is not taken, fill username string.
check_username(&state, &mut username, &name);
// If not empty we want to quit the loop else we want to quit function.
if !username.is_empty() {
break;
} else {
// Only send our client that username is taken.
let _ = sender
.send(Message::Text(String::from("Username already taken.")))
.await;
return;
}
}
}
// Subscribe before sending joined message.
let mut rx = state.tx.subscribe();
// Send joined message to all subscribers.
let msg = format!("{} joined.", username);
let _ = state.tx.send(msg);
// This task will receive broadcast messages and send text message to our client.
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
// In any websocket error, break loop.
if sender.send(Message::Text(msg)).await.is_err() {
break;
}
}
});
// Clone things we want to pass to the receiving task.
let tx = state.tx.clone();
let name = username.clone();
// This task will receive messages from client and send them to broadcast subscribers.
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
// Add username before message.
let _ = tx.send(format!("{}: {}", name, text));
}
});
// If any one of the tasks exit, abort the other.
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
// Send user left message.
let msg = format!("{} left.", username);
let _ = state.tx.send(msg);
// Remove username from map so new clients can take it.
state.user_set.lock().unwrap().remove(&username);
}
/// Tries to claim `name` as a connection's username.
///
/// If `name` is non-empty and not already in `state.user_set`, it is inserted
/// into the set and appended to `string`. Otherwise nothing changes. Callers
/// detect failure by `string` still being empty.
fn check_username(state: &AppState, string: &mut String, name: &str) {
    // An empty name can never succeed, because an empty `string` is the failure
    // signal. Don't reserve it in the shared set, where the caller would never
    // remove it again.
    if name.is_empty() {
        return;
    }

    let mut user_set = state.user_set.lock().unwrap();
    // `contains` first so no `String` is allocated when the name is taken.
    if !user_set.contains(name) {
        user_set.insert(name.to_owned());
        string.push_str(name);
    }
}
/// `GET /`: serves the chat page.
///
/// The UTF-8 HTML at `chat/chat.html` (relative to this source file) is
/// embedded into the binary at compile time.
async fn index() -> Html<&'static str> {
    const PAGE: &str = include_str!("chat/chat.html");
    Html(PAGE)
}