diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 220c458e..316f2150 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **change:** Avoid cloning `Arc` during deserialization of `Path` +- **added:** `axum::serve::Serve::tcp_nodelay` and `axum::serve::WithGracefulShutdown::tcp_nodelay` ([#2653]) + +[#2653]: https://github.com/tokio-rs/axum/pull/2653 # 0.7.5 (24. March, 2024) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index a2df756b..c5c54086 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -101,6 +101,7 @@ where Serve { tcp_listener, make_service, + tcp_nodelay: None, _marker: PhantomData, } } @@ -111,6 +112,7 @@ where pub struct Serve { tcp_listener: TcpListener, make_service: M, + tcp_nodelay: Option, _marker: PhantomData, } @@ -145,9 +147,35 @@ impl Serve { tcp_listener: self.tcp_listener, make_service: self.make_service, signal, + tcp_nodelay: self.tcp_nodelay, _marker: PhantomData, } } + + /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. + /// + /// See also [`TcpStream::set_nodelay`]. + /// + /// # Example + /// ``` + /// use axum::{Router, routing::get}; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// axum::serve(listener, router) + /// .tcp_nodelay(true) + /// .await + /// .unwrap(); + /// # }; + /// ``` + pub fn tcp_nodelay(self, nodelay: bool) -> Self { + Self { + tcp_nodelay: Some(nodelay), + ..self + } + } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] @@ -159,12 +187,14 @@ where let Self { tcp_listener, make_service, + tcp_nodelay, _marker: _, } = self; f.debug_struct("Serve") .field("tcp_listener", tcp_listener) .field("make_service", make_service) + .field("tcp_nodelay", tcp_nodelay) .finish() } } @@ -185,6 +215,7 @@ where let Self { tcp_listener, mut make_service, + tcp_nodelay, _marker: _, } = self; @@ -193,6 +224,13 @@ where Some(conn) => conn, None => continue, }; + + if let Some(nodelay) = tcp_nodelay { + if let Err(err) = tcp_stream.set_nodelay(nodelay) { + trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + } + } + let tcp_stream = TokioIo::new(tcp_stream); poll_fn(|cx| make_service.poll_ready(cx)) @@ -239,9 +277,42 @@ pub struct WithGracefulShutdown { tcp_listener: TcpListener, make_service: M, signal: F, + tcp_nodelay: Option, _marker: PhantomData, } +impl WithGracefulShutdown { + /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. + /// + /// See also [`TcpStream::set_nodelay`]. + /// + /// # Example + /// ``` + /// use axum::{Router, routing::get}; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// axum::serve(listener, router) + /// .with_graceful_shutdown(shutdown_signal()) + /// .tcp_nodelay(true) + /// .await + /// .unwrap(); + /// # }; + /// + /// async fn shutdown_signal() { + /// // ... + /// } + /// ``` + pub fn tcp_nodelay(self, nodelay: bool) -> Self { + Self { + tcp_nodelay: Some(nodelay), + ..self + } + } +} + #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for WithGracefulShutdown where @@ -254,6 +325,7 @@ where tcp_listener, make_service, signal, + tcp_nodelay, _marker: _, } = self; @@ -261,6 +333,7 @@ where .field("tcp_listener", tcp_listener) .field("make_service", make_service) .field("signal", signal) + .field("tcp_nodelay", tcp_nodelay) .finish() } } @@ -282,6 +355,7 @@ where tcp_listener, mut make_service, signal, + tcp_nodelay, _marker: _, } = self; @@ -309,6 +383,13 @@ where break; } }; + + if let Some(nodelay) = tcp_nodelay { + if let Err(err) = tcp_stream.set_nodelay(nodelay) { + trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + } + } + let tcp_stream = TokioIo::new(tcp_stream); trace!("connection {remote_addr} accepted"); @@ -557,6 +638,20 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_make_service_with_connect_info::(), ); + + // nodelay + serve( + TcpListener::bind(addr).await.unwrap(), + handler.into_service(), + ) + .tcp_nodelay(true); + + serve( + TcpListener::bind(addr).await.unwrap(), + handler.into_service(), + ) + .with_graceful_shutdown(async { /*...*/ }) + .tcp_nodelay(true); } async fn handler() {}