mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
Properly respond with sec-websocket-protocol under http/2 (#3141)
This commit is contained in:
parent
0e6e96fb8c
commit
6c9cabf985
1 changed files with 32 additions and 18 deletions
|
@ -338,7 +338,7 @@ impl<F> WebSocketUpgrade<F> {
|
|||
callback(socket).await;
|
||||
});
|
||||
|
||||
if let Some(sec_websocket_key) = &self.sec_websocket_key {
|
||||
let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
|
||||
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
|
||||
|
||||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
|
@ -346,26 +346,30 @@ impl<F> WebSocketUpgrade<F> {
|
|||
#[allow(clippy::declare_interior_mutable_const)]
|
||||
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
|
||||
|
||||
let mut builder = Response::builder()
|
||||
Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(header::CONNECTION, UPGRADE)
|
||||
.header(header::UPGRADE, WEBSOCKET)
|
||||
.header(
|
||||
header::SEC_WEBSOCKET_ACCEPT,
|
||||
sign(sec_websocket_key.as_bytes()),
|
||||
);
|
||||
|
||||
if let Some(protocol) = self.protocol {
|
||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
builder.body(Body::empty()).unwrap()
|
||||
)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
} else {
|
||||
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
|
||||
// with a 2XX with an empty body:
|
||||
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
|
||||
Response::new(Body::empty())
|
||||
};
|
||||
|
||||
if let Some(protocol) = self.protocol {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1092,10 +1096,11 @@ mod tests {
|
|||
#[crate::test]
|
||||
async fn integration_test() {
|
||||
let addr = spawn_service(echo_app());
|
||||
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
|
||||
.await
|
||||
.unwrap();
|
||||
test_echo_app(socket).await;
|
||||
let uri = format!("ws://{addr}/echo").try_into().unwrap();
|
||||
let req = tungstenite::client::ClientRequestBuilder::new(uri)
|
||||
.with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
|
||||
let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
|
||||
test_echo_app(socket, response.headers()).await;
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
|
@ -1123,21 +1128,22 @@ mod tests {
|
|||
.extension(hyper::ext::Protocol::from_static("websocket"))
|
||||
.uri("/echo")
|
||||
.header("sec-websocket-version", "13")
|
||||
.header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
|
||||
.header("Host", "server.example.com")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = send_request.send_request(req).await.unwrap();
|
||||
let mut response = send_request.send_request(req).await.unwrap();
|
||||
let status = response.status();
|
||||
if status != 200 {
|
||||
let body = response.into_body().collect().await.unwrap().to_bytes();
|
||||
let body = std::str::from_utf8(&body).unwrap();
|
||||
panic!("response status was {status}: {body}");
|
||||
}
|
||||
let upgraded = hyper::upgrade::on(response).await.unwrap();
|
||||
let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
|
||||
let upgraded = TokioIo::new(upgraded);
|
||||
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
|
||||
test_echo_app(socket).await;
|
||||
test_echo_app(socket, response.headers()).await;
|
||||
}
|
||||
|
||||
fn echo_app() -> Router {
|
||||
|
@ -1158,11 +1164,19 @@ mod tests {
|
|||
|
||||
Router::new().route(
|
||||
"/echo",
|
||||
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
|
||||
any(|ws: WebSocketUpgrade| {
|
||||
ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
|
||||
const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
|
||||
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
mut socket: WebSocketStream<S>,
|
||||
headers: &http::HeaderMap,
|
||||
) {
|
||||
assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");
|
||||
|
||||
let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
|
||||
socket.send(input.clone()).await.unwrap();
|
||||
let output = socket.next().await.unwrap().unwrap();
|
||||
|
|
Loading…
Add table
Reference in a new issue