Add websocket integration test (#2477)

This commit is contained in:
David Pedersen 2023-12-31 11:14:26 +01:00 committed by GitHub
parent 48d169016a
commit 71eedc6d6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 13 deletions

View file

@ -124,6 +124,7 @@ serde_json = "1.0"
time = { version = "0.3", features = ["serde-human-readable"] } time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1" tokio-stream = "0.1"
tokio-tungstenite = "0.21"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["json"] } tracing-subscriber = { version = "0.3", features = ["json"] }
uuid = { version = "1.0", features = ["serde", "v4"] } uuid = { version = "1.0", features = ["serde", "v4"] }

View file

@ -830,9 +830,12 @@ pub mod close_code {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::future::ready;
use super::*; use super::*;
use crate::{body::Body, routing::get, Router}; use crate::{body::Body, routing::get, test_helpers::spawn_service, Router};
use http::{Request, Version}; use http::{Request, Version};
use tokio_tungstenite::tungstenite;
use tower::ServiceExt; use tower::ServiceExt;
#[crate::test] #[crate::test]
@ -877,4 +880,47 @@ mod tests {
} }
let _: Router = Router::new().route("/", get(handler)); let _: Router = Router::new().route("/", get(handler));
} }
#[crate::test]
async fn integration_test() {
let app = Router::new().route(
"/echo",
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
);
async fn handle_socket(mut socket: WebSocket) {
while let Some(Ok(msg)) = socket.recv().await {
match msg {
Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
if socket.send(msg).await.is_err() {
break;
}
}
Message::Ping(_) | Message::Pong(_) => {
// tungstenite will respond to pings automatically
}
}
}
}
let addr = spawn_service(app);
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
let input = tungstenite::Message::Text("foobar".to_owned());
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
assert_eq!(input, output);
socket
.send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
.await
.unwrap();
let output = socket.next().await.unwrap().unwrap();
assert_eq!(
output,
tungstenite::Message::Pong("ping".to_owned().into_bytes())
);
}
} }

View file

@ -10,6 +10,27 @@ use tokio::net::TcpListener;
use tower::make::Shared; use tower::make::Shared;
use tower_service::Service; use tower_service::Service;
pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
where
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
std_listener.set_nonblocking(true).unwrap();
let listener = TcpListener::from_std(std_listener).unwrap();
let addr = listener.local_addr().unwrap();
println!("Listening on {addr}");
tokio::spawn(async move {
serve(listener, Shared::new(svc))
.await
.expect("server error")
});
addr
}
pub(crate) struct TestClient { pub(crate) struct TestClient {
client: reqwest::Client, client: reqwest::Client,
addr: SocketAddr, addr: SocketAddr,
@ -21,18 +42,7 @@ impl TestClient {
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
{ {
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = spawn_service(svc);
std_listener.set_nonblocking(true).unwrap();
let listener = TcpListener::from_std(std_listener).unwrap();
let addr = listener.local_addr().unwrap();
println!("Listening on {addr}");
tokio::spawn(async move {
serve(listener, Shared::new(svc))
.await
.expect("server error")
});
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())