Add on_failed_upgrade callback to WebSocketUpgrade (#1539)

* Add `on_failed_upgrade` callback to `WebSocket`

Previously if upgrading a connection to a WebSocket connection failed
the background task would panic. There was no way to customize that so
users that might wanna report the error was out of luck.

Panicking also wasn't great because users might abort on panics which
would bring down the server.

* changelog

* Apply suggestions from code review

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-11-18 11:53:04 +01:00 committed by GitHub
parent 7090649377
commit ba8e9c1b21
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 8 deletions

View file

@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])
- **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do
when upgrading a connection fails ([#1539])
[#1539]: https://github.com/tokio-rs/axum/pull/1539
[#1521]: https://github.com/tokio-rs/axum/pull/1521

View file

@ -134,18 +134,29 @@ use tokio_tungstenite::{
/// rejected.
///
/// See the [module docs](self) for an example.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade {
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
impl WebSocketUpgrade {
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WebSocketUpgrade")
.field("config", &self.config)
.field("protocol", &self.protocol)
.field("sec_websocket_key", &self.sec_websocket_key)
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
.finish_non_exhaustive()
}
}
impl<F> WebSocketUpgrade<F> {
/// Set the size of the internal message send queue.
pub fn max_send_queue(mut self, max: usize) -> Self {
self.config.max_send_queue = Some(max);
@ -231,24 +242,71 @@ impl WebSocketUpgrade {
self
}
/// Provide a callback to call if upgrading the connection fails.
///
/// The connection upgrade is performed in a background task. If that fails this callback
/// will be called.
///
/// By default any errors will be silently ignored.
///
/// # Example
///
/// ```
/// use axum::{
/// extract::{WebSocketUpgrade},
/// response::Response,
/// };
///
/// async fn handler(ws: WebSocketUpgrade) -> Response {
/// ws.on_failed_upgrade(|error| {
/// report_error(error);
/// })
/// .on_upgrade(|socket| async { /* ... */ })
/// }
/// #
/// # fn report_error(_: axum::Error) {}
/// ```
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
where
C: OnFailedUpdgrade,
{
WebSocketUpgrade {
config: self.config,
protocol: self.protocol,
sec_websocket_key: self.sec_websocket_key,
on_upgrade: self.on_upgrade,
on_failed_upgrade: callback,
sec_websocket_protocol: self.sec_websocket_protocol,
}
}
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
///
/// When using `WebSocketUpgrade`, the response produced by this method
/// should be returned from the handler. See the [module docs](self) for an
/// example.
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
C: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
F: OnFailedUpdgrade,
{
let on_upgrade = self.on_upgrade;
let config = self.config;
let on_failed_upgrade = self.on_failed_upgrade;
let protocol = self.protocol.clone();
tokio::spawn(async move {
let upgraded = on_upgrade.await.expect("connection upgrade failed");
let upgraded = match on_upgrade.await {
Ok(upgraded) => upgraded,
Err(err) => {
on_failed_upgrade.call(Error::new(err));
return;
}
};
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
.await;
@ -281,8 +339,37 @@ impl WebSocketUpgrade {
}
}
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpdgrade: Send + 'static {
/// Call the callback.
fn call(self, error: Error);
}
impl<F> OnFailedUpdgrade for F
where
F: FnOnce(Error) + Send + 'static,
{
fn call(self, error: Error) {
self(error)
}
}
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
#[inline]
fn call(self, _error: Error) {}
}
#[async_trait]
impl<S> FromRequestParts<S> for WebSocketUpgrade
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
where
S: Send + Sync,
{
@ -323,6 +410,7 @@ where
sec_websocket_key,
on_upgrade,
sec_websocket_protocol,
on_failed_upgrade: DefaultOnFailedUpdgrade,
})
}
}
@ -722,7 +810,7 @@ pub mod close_code {
#[cfg(test)]
mod tests {
use super::*;
use crate::{body::Body, routing::get};
use crate::{body::Body, routing::get, Router};
use http::{Request, Version};
use tower::ServiceExt;
@ -751,4 +839,21 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
}
#[allow(dead_code)]
fn default_on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
#[allow(dead_code)]
fn on_failed_upgrade() {
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
.on_upgrade(|_| async {})
}
let _: Router = Router::new().route("/", get(handler));
}
}