mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-24 16:17:56 +01:00
Add websocket integration test (#2477)
This commit is contained in:
parent
48d169016a
commit
71eedc6d6c
3 changed files with 70 additions and 13 deletions
|
@ -124,6 +124,7 @@ serde_json = "1.0"
|
||||||
time = { version = "0.3", features = ["serde-human-readable"] }
|
time = { version = "0.3", features = ["serde-human-readable"] }
|
||||||
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
|
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
|
tokio-tungstenite = "0.21"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["json"] }
|
tracing-subscriber = { version = "0.3", features = ["json"] }
|
||||||
uuid = { version = "1.0", features = ["serde", "v4"] }
|
uuid = { version = "1.0", features = ["serde", "v4"] }
|
||||||
|
|
|
@ -830,9 +830,12 @@ pub mod close_code {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::future::ready;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{body::Body, routing::get, Router};
|
use crate::{body::Body, routing::get, test_helpers::spawn_service, Router};
|
||||||
use http::{Request, Version};
|
use http::{Request, Version};
|
||||||
|
use tokio_tungstenite::tungstenite;
|
||||||
use tower::ServiceExt;
|
use tower::ServiceExt;
|
||||||
|
|
||||||
#[crate::test]
|
#[crate::test]
|
||||||
|
@ -877,4 +880,47 @@ mod tests {
|
||||||
}
|
}
|
||||||
let _: Router = Router::new().route("/", get(handler));
|
let _: Router = Router::new().route("/", get(handler));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn integration_test() {
|
||||||
|
let app = Router::new().route(
|
||||||
|
"/echo",
|
||||||
|
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
|
||||||
|
);
|
||||||
|
|
||||||
|
async fn handle_socket(mut socket: WebSocket) {
|
||||||
|
while let Some(Ok(msg)) = socket.recv().await {
|
||||||
|
match msg {
|
||||||
|
Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
|
||||||
|
if socket.send(msg).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Ping(_) | Message::Pong(_) => {
|
||||||
|
// tungstenite will respond to pings automatically
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let addr = spawn_service(app);
|
||||||
|
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let input = tungstenite::Message::Text("foobar".to_owned());
|
||||||
|
socket.send(input.clone()).await.unwrap();
|
||||||
|
let output = socket.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(input, output);
|
||||||
|
|
||||||
|
socket
|
||||||
|
.send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let output = socket.next().await.unwrap().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
output,
|
||||||
|
tungstenite::Message::Pong("ping".to_owned().into_bytes())
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,27 @@ use tokio::net::TcpListener;
|
||||||
use tower::make::Shared;
|
use tower::make::Shared;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
|
pub(crate) fn spawn_service<S>(svc: S) -> SocketAddr
|
||||||
|
where
|
||||||
|
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||||
|
S::Future: Send,
|
||||||
|
{
|
||||||
|
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
|
std_listener.set_nonblocking(true).unwrap();
|
||||||
|
let listener = TcpListener::from_std(std_listener).unwrap();
|
||||||
|
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
println!("Listening on {addr}");
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
serve(listener, Shared::new(svc))
|
||||||
|
.await
|
||||||
|
.expect("server error")
|
||||||
|
});
|
||||||
|
|
||||||
|
addr
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct TestClient {
|
pub(crate) struct TestClient {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
|
@ -21,18 +42,7 @@ impl TestClient {
|
||||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
{
|
{
|
||||||
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
|
let addr = spawn_service(svc);
|
||||||
std_listener.set_nonblocking(true).unwrap();
|
|
||||||
let listener = TcpListener::from_std(std_listener).unwrap();
|
|
||||||
|
|
||||||
let addr = listener.local_addr().unwrap();
|
|
||||||
println!("Listening on {addr}");
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
serve(listener, Shared::new(svc))
|
|
||||||
.await
|
|
||||||
.expect("server error")
|
|
||||||
});
|
|
||||||
|
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.redirect(reqwest::redirect::Policy::none())
|
.redirect(reqwest::redirect::Policy::none())
|
||||||
|
|
Loading…
Reference in a new issue