mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-22 07:08:16 +01:00
4194cf70da
Fixes https://github.com/tokio-rs/axum/issues/111 Example usage: ```rust use axum::{ prelude::*, extract::ws::{WebSocketUpgrade, WebSocket}, response::IntoResponse, }; let app = route("/ws", get(handler)); async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse { ws.on_upgrade(handle_socket) } async fn handle_socket(mut socket: WebSocket) { while let Some(msg) = socket.recv().await { let msg = if let Ok(msg) = msg { msg } else { // client disconnected return; }; if socket.send(msg).await.is_err() { // client disconnected return; } } } ```
137 lines
4 KiB
Rust
137 lines
4 KiB
Rust
//! Example chat application.
|
|
//!
|
|
//! Run with
|
|
//!
|
|
//! ```not_rust
|
|
//! cargo run --features=ws --example chat
|
|
//! ```
|
|
|
|
use std::collections::HashSet;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use futures::{sink::SinkExt, stream::StreamExt};
|
|
|
|
use tokio::sync::broadcast;
|
|
|
|
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
|
|
use axum::prelude::*;
|
|
use axum::response::{Html, IntoResponse};
|
|
use axum::AddExtensionLayer;
|
|
|
|
// Our shared state
|
|
struct AppState {
|
|
user_set: Mutex<HashSet<String>>,
|
|
tx: broadcast::Sender<String>,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let user_set = Mutex::new(HashSet::new());
|
|
let (tx, _rx) = broadcast::channel(100);
|
|
|
|
let app_state = Arc::new(AppState { user_set, tx });
|
|
|
|
let app = route("/", get(index))
|
|
.route("/websocket", get(websocket_handler))
|
|
.layer(AddExtensionLayer::new(app_state));
|
|
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
|
|
|
axum::Server::bind(&addr)
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
async fn websocket_handler(
|
|
ws: WebSocketUpgrade,
|
|
extract::Extension(state): extract::Extension<Arc<AppState>>,
|
|
) -> impl IntoResponse {
|
|
ws.on_upgrade(|socket| websocket(socket, state))
|
|
}
|
|
|
|
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
|
|
// By splitting we can send and receive at the same time.
|
|
let (mut sender, mut receiver) = stream.split();
|
|
|
|
// Username gets set in the receive loop, if its valid
|
|
let mut username = String::new();
|
|
|
|
// Loop until a text message is found.
|
|
while let Some(Ok(msg)) = receiver.next().await {
|
|
if let Some(name) = msg.to_str() {
|
|
// If username that is sent by client is not taken, fill username string.
|
|
check_username(&state, &mut username, name);
|
|
|
|
// If not empty we want to quit the loop else we want to quit function.
|
|
if !username.is_empty() {
|
|
break;
|
|
} else {
|
|
// Only send our client that username is taken.
|
|
let _ = sender.send(Message::text("Username already taken.")).await;
|
|
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Subscribe before sending joined message.
|
|
let mut rx = state.tx.subscribe();
|
|
|
|
// Send joined message to all subscribers.
|
|
let msg = format!("{} joined.", username);
|
|
let _ = state.tx.send(msg);
|
|
|
|
// This task will receive broadcast messages and send text message to our client.
|
|
let mut send_task = tokio::spawn(async move {
|
|
while let Ok(msg) = rx.recv().await {
|
|
// In any websocket error, break loop.
|
|
if sender.send(Message::text(msg)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
// Clone things we want to pass to the receiving task.
|
|
let tx = state.tx.clone();
|
|
let name = username.clone();
|
|
|
|
// This task will receive messages from client and send them to broadcast subscribers.
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = receiver.next().await {
|
|
if let Some(text) = msg.to_str() {
|
|
// Add username before message.
|
|
let _ = tx.send(format!("{}: {}", name, text));
|
|
}
|
|
}
|
|
});
|
|
|
|
// If any one of the tasks exit, abort the other.
|
|
tokio::select! {
|
|
_ = (&mut send_task) => recv_task.abort(),
|
|
_ = (&mut recv_task) => send_task.abort(),
|
|
};
|
|
|
|
// Send user left message.
|
|
let msg = format!("{} left.", username);
|
|
let _ = state.tx.send(msg);
|
|
|
|
// Remove username from map so new clients can take it.
|
|
state.user_set.lock().unwrap().remove(&username);
|
|
}
|
|
|
|
fn check_username(state: &AppState, string: &mut String, name: &str) {
|
|
let mut user_set = state.user_set.lock().unwrap();
|
|
|
|
if !user_set.contains(name) {
|
|
user_set.insert(name.to_owned());
|
|
|
|
string.push_str(name);
|
|
}
|
|
}
|
|
|
|
// Include utf-8 file at **compile** time.
|
|
async fn index() -> Html<&'static str> {
|
|
Html(std::include_str!("chat/chat.html"))
|
|
}
|