mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-11 12:31:25 +01:00
Fix websockets failing on Firefox (#76)
Axum expected the `Connection` header to be _exactly_ `upgrade`. Turns out thats a bit too strict as this didn't work in Firefox. Turns out `Connection` just has to contain `upgrade`. At least that is what [warp does](https://github.com/seanmonstar/warp/blob/master/src/filters/ws.rs#L46).
This commit is contained in:
parent
10bedca796
commit
69ae7a686a
2 changed files with 21 additions and 12 deletions
|
@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Improve documentation for routing ([#71](https://github.com/tokio-rs/axum/pull/71))
|
||||
- Clarify required response body type when routing to `tower::Service`s ([#69](https://github.com/tokio-rs/axum/pull/69))
|
||||
- Add `axum::body::box_body` to converting an `http_body::Body` to `axum::body::BoxBody` ([#69](https://github.com/tokio-rs/axum/pull/69))
|
||||
- Fix WebSockets failing on Firefox ([#76](https://github.com/tokio-rs/axum/pull/76))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
@ -236,29 +236,21 @@ where
|
|||
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||
}
|
||||
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::CONNECTION,
|
||||
HeaderValue::from_static("upgrade"),
|
||||
) {
|
||||
if !header_contains(&req, header::CONNECTION, "upgrade") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||
if !header_eq(&req, header::UPGRADE, "websocket") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(
|
||||
&req,
|
||||
header::SEC_WEBSOCKET_VERSION,
|
||||
HeaderValue::from_static("13"),
|
||||
) {
|
||||
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
|
@ -320,6 +312,8 @@ where
|
|||
}
|
||||
|
||||
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
|
||||
dbg!((status, body));
|
||||
|
||||
let res = Response::builder()
|
||||
.status(status)
|
||||
.body(box_body(Full::from(body)))
|
||||
|
@ -327,7 +321,7 @@ fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBod
|
|||
Ok(res)
|
||||
}
|
||||
|
||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
|
||||
if let Some(header) = req.headers().get(&key) {
|
||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||
} else {
|
||||
|
@ -335,6 +329,20 @@ fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
|||
}
|
||||
}
|
||||
|
||||
fn header_contains<B>(req: &Request<B>, key: HeaderName, value: &'static str) -> bool {
|
||||
let header = if let Some(header) = req.headers().get(&key) {
|
||||
header
|
||||
} else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
|
||||
header.to_ascii_lowercase().contains(value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn sign(key: &[u8]) -> HeaderValue {
|
||||
let mut sha1 = Sha1::default();
|
||||
sha1.update(key);
|
||||
|
|
Loading…
Reference in a new issue