mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 14:46:32 +01:00
Add websocket integration test (#2477)
This commit is contained in:
parent
48d169016a
commit
71eedc6d6c
3 changed files with 70 additions and 13 deletions
|
@ -124,6 +124,7 @@ serde_json = "1.0"
|
|||
time = { version = "0.3", features = ["serde-human-readable"] }
|
||||
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
|
||||
tokio-stream = "0.1"
|
||||
tokio-tungstenite = "0.21"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["json"] }
|
||||
uuid = { version = "1.0", features = ["serde", "v4"] }
|
||||
|
|
|
@ -830,9 +830,12 @@ pub mod close_code {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::future::ready;
|
||||
|
||||
use super::*;
|
||||
use crate::{body::Body, routing::get, Router};
|
||||
use crate::{body::Body, routing::get, test_helpers::spawn_service, Router};
|
||||
use http::{Request, Version};
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[crate::test]
|
||||
|
@ -877,4 +880,47 @@ mod tests {
|
|||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn integration_test() {
|
||||
let app = Router::new().route(
|
||||
"/echo",
|
||||
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
|
||||
);
|
||||
|
||||
async fn handle_socket(mut socket: WebSocket) {
|
||||
while let Some(Ok(msg)) = socket.recv().await {
|
||||
match msg {
|
||||
Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
|
||||
if socket.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Message::Ping(_) | Message::Pong(_) => {
|
||||
// tungstenite will respond to pings automatically
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let addr = spawn_service(app);
|
||||
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let input = tungstenite::Message::Text("foobar".to_owned());
|
||||
socket.send(input.clone()).await.unwrap();
|
||||
let output = socket.next().await.unwrap().unwrap();
|
||||
assert_eq!(input, output);
|
||||
|
||||
socket
|
||||
.send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
|
||||
.await
|
||||
.unwrap();
|
||||
let output = socket.next().await.unwrap().unwrap();
|
||||
assert_eq!(
|
||||
output,
|
||||
tungstenite::Message::Pong("ping".to_owned().into_bytes())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,27 @@ use tokio::net::TcpListener;
|
|||
use tower::make::Shared;
|
||||
use tower_service::Service;
|
||||
|
||||
pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
|
||||
where
|
||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
std_listener.set_nonblocking(true).unwrap();
|
||||
let listener = TcpListener::from_std(std_listener).unwrap();
|
||||
|
||||
let addr = listener.local_addr().unwrap();
|
||||
println!("Listening on {addr}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
serve(listener, Shared::new(svc))
|
||||
.await
|
||||
.expect("server error")
|
||||
});
|
||||
|
||||
addr
|
||||
}
|
||||
|
||||
pub(crate) struct TestClient {
|
||||
client: reqwest::Client,
|
||||
addr: SocketAddr,
|
||||
|
@ -21,18 +42,7 @@ impl TestClient {
|
|||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
std_listener.set_nonblocking(true).unwrap();
|
||||
let listener = TcpListener::from_std(std_listener).unwrap();
|
||||
|
||||
let addr = listener.local_addr().unwrap();
|
||||
println!("Listening on {addr}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
serve(listener, Shared::new(svc))
|
||||
.await
|
||||
.expect("server error")
|
||||
});
|
||||
let addr = spawn_service(svc);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
|
|
Loading…
Reference in a new issue