Fix websockets failing on Firefox (#76)

Axum expected the `Connection` header to be _exactly_ `upgrade`. Turns
out thats a bit too strict as this didn't work in Firefox.

Turns out `Connection` just has to contain `upgrade`. At least that is
what [warp does](https://github.com/seanmonstar/warp/blob/master/src/filters/ws.rs#L46).
This commit is contained in:
David Pedersen 2021-08-01 21:00:38 +02:00 committed by GitHub
parent 10bedca796
commit 69ae7a686a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 12 deletions

View file

@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improve documentation for routing ([#71](https://github.com/tokio-rs/axum/pull/71))
- Clarify required response body type when routing to `tower::Service`s ([#69](https://github.com/tokio-rs/axum/pull/69))
- Add `axum::body::box_body` to converting an `http_body::Body` to `axum::body::BoxBody` ([#69](https://github.com/tokio-rs/axum/pull/69))
- Fix WebSockets failing on Firefox ([#76](https://github.com/tokio-rs/axum/pull/76))
## Breaking changes

View file

@ -236,29 +236,21 @@ where
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
}
if !header_eq(
&req,
header::CONNECTION,
HeaderValue::from_static("upgrade"),
) {
if !header_contains(&req, header::CONNECTION, "upgrade") {
return response(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
if !header_eq(&req, header::UPGRADE, "websocket") {
return response(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
if !header_eq(
&req,
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
) {
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
@ -320,6 +312,8 @@ where
}
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
dbg!((status, body));
let res = Response::builder()
.status(status)
.body(box_body(Full::from(body)))
@ -327,7 +321,7 @@ fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBod
Ok(res)
}
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
@ -335,6 +329,20 @@ fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
}
}
fn header_contains<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) {
header
} else {
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
header.to_ascii_lowercase().contains(value)
} else {
false
}
}
fn sign(key: &[u8]) -> HeaderValue {
let mut sha1 = Sha1::default();
sha1.update(key);