mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 00:50:32 +01:00
Add on_failed_upgrade
callback to WebSocketUpgrade
(#1539)
* Add `on_failed_upgrade` callback to `WebSocket` Previously if upgrading a connection to a WebSocket connection failed the background task would panic. There was no way to customize that so users that might wanna report the error was out of luck. Panicking also wasn't great because users might abort on panics which would bring down the server. * changelog * Apply suggestions from code review Co-authored-by: Jonas Platte <jplatte+git@posteo.de> Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
parent
7090649377
commit
ba8e9c1b21
2 changed files with 117 additions and 8 deletions
|
@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])
|
||||
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])
|
||||
- **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do
|
||||
when upgrading a connection fails ([#1539])
|
||||
|
||||
[#1539]: https://github.com/tokio-rs/axum/pull/1539
|
||||
|
||||
[#1521]: https://github.com/tokio-rs/axum/pull/1521
|
||||
|
||||
|
|
|
@ -134,18 +134,29 @@ use tokio_tungstenite::{
|
|||
/// rejected.
|
||||
///
|
||||
/// See the [module docs](self) for an example.
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
|
||||
pub struct WebSocketUpgrade {
|
||||
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
|
||||
config: WebSocketConfig,
|
||||
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
|
||||
protocol: Option<HeaderValue>,
|
||||
sec_websocket_key: HeaderValue,
|
||||
on_upgrade: OnUpgrade,
|
||||
on_failed_upgrade: F,
|
||||
sec_websocket_protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
impl WebSocketUpgrade {
|
||||
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("WebSocketUpgrade")
|
||||
.field("config", &self.config)
|
||||
.field("protocol", &self.protocol)
|
||||
.field("sec_websocket_key", &self.sec_websocket_key)
|
||||
.field("sec_websocket_protocol", &self.sec_websocket_protocol)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> WebSocketUpgrade<F> {
|
||||
/// Set the size of the internal message send queue.
|
||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||
self.config.max_send_queue = Some(max);
|
||||
|
@ -231,24 +242,71 @@ impl WebSocketUpgrade {
|
|||
self
|
||||
}
|
||||
|
||||
/// Provide a callback to call if upgrading the connection fails.
|
||||
///
|
||||
/// The connection upgrade is performed in a background task. If that fails this callback
|
||||
/// will be called.
|
||||
///
|
||||
/// By default any errors will be silently ignored.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{
|
||||
/// extract::{WebSocketUpgrade},
|
||||
/// response::Response,
|
||||
/// };
|
||||
///
|
||||
/// async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
/// ws.on_failed_upgrade(|error| {
|
||||
/// report_error(error);
|
||||
/// })
|
||||
/// .on_upgrade(|socket| async { /* ... */ })
|
||||
/// }
|
||||
/// #
|
||||
/// # fn report_error(_: axum::Error) {}
|
||||
/// ```
|
||||
pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
|
||||
where
|
||||
C: OnFailedUpdgrade,
|
||||
{
|
||||
WebSocketUpgrade {
|
||||
config: self.config,
|
||||
protocol: self.protocol,
|
||||
sec_websocket_key: self.sec_websocket_key,
|
||||
on_upgrade: self.on_upgrade,
|
||||
on_failed_upgrade: callback,
|
||||
sec_websocket_protocol: self.sec_websocket_protocol,
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize upgrading the connection and call the provided callback with
|
||||
/// the stream.
|
||||
///
|
||||
/// When using `WebSocketUpgrade`, the response produced by this method
|
||||
/// should be returned from the handler. See the [module docs](self) for an
|
||||
/// example.
|
||||
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
|
||||
pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
|
||||
where
|
||||
F: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
C: FnOnce(WebSocket) -> Fut + Send + 'static,
|
||||
Fut: Future<Output = ()> + Send + 'static,
|
||||
F: OnFailedUpdgrade,
|
||||
{
|
||||
let on_upgrade = self.on_upgrade;
|
||||
let config = self.config;
|
||||
let on_failed_upgrade = self.on_failed_upgrade;
|
||||
|
||||
let protocol = self.protocol.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.expect("connection upgrade failed");
|
||||
let upgraded = match on_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(err) => {
|
||||
on_failed_upgrade.call(Error::new(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
.await;
|
||||
|
@ -281,8 +339,37 @@ impl WebSocketUpgrade {
|
|||
}
|
||||
}
|
||||
|
||||
/// What to do when a connection upgrade fails.
|
||||
///
|
||||
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
|
||||
pub trait OnFailedUpdgrade: Send + 'static {
|
||||
/// Call the callback.
|
||||
fn call(self, error: Error);
|
||||
}
|
||||
|
||||
impl<F> OnFailedUpdgrade for F
|
||||
where
|
||||
F: FnOnce(Error) + Send + 'static,
|
||||
{
|
||||
fn call(self, error: Error) {
|
||||
self(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
|
||||
///
|
||||
/// It simply ignores the error.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug)]
|
||||
pub struct DefaultOnFailedUpdgrade;
|
||||
|
||||
impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
|
||||
#[inline]
|
||||
fn call(self, _error: Error) {}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for WebSocketUpgrade
|
||||
impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
|
@ -323,6 +410,7 @@ where
|
|||
sec_websocket_key,
|
||||
on_upgrade,
|
||||
sec_websocket_protocol,
|
||||
on_failed_upgrade: DefaultOnFailedUpdgrade,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -722,7 +810,7 @@ pub mod close_code {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{body::Body, routing::get};
|
||||
use crate::{body::Body, routing::get, Router};
|
||||
use http::{Request, Version};
|
||||
use tower::ServiceExt;
|
||||
|
||||
|
@ -751,4 +839,21 @@ mod tests {
|
|||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn default_on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn on_failed_upgrade() {
|
||||
async fn handler(ws: WebSocketUpgrade) -> Response {
|
||||
ws.on_failed_upgrade(|_error: Error| println!("oops!"))
|
||||
.on_upgrade(|_| async {})
|
||||
}
|
||||
let _: Router = Router::new().route("/", get(handler));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue