Add Serve::tcp_nodelay (#2653)

This commit is contained in:
Liss Heidrich 2024-04-01 23:26:18 +02:00 committed by GitHub
parent dbd6178393
commit 50c035c20b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 98 additions and 0 deletions

View file

@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **change:** Avoid cloning `Arc` during deserialization of `Path` - **change:** Avoid cloning `Arc` during deserialization of `Path`
- **added:** `axum::serve::Serve::tcp_nodelay` and `axum::serve::WithGracefulShutdown::tcp_nodelay` ([#2653])
[#2653]: https://github.com/tokio-rs/axum/pull/2653
# 0.7.5 (24. March, 2024) # 0.7.5 (24. March, 2024)

View file

@ -101,6 +101,7 @@ where
Serve { Serve {
tcp_listener, tcp_listener,
make_service, make_service,
tcp_nodelay: None,
_marker: PhantomData, _marker: PhantomData,
} }
} }
@ -111,6 +112,7 @@ where
pub struct Serve<M, S> { pub struct Serve<M, S> {
tcp_listener: TcpListener, tcp_listener: TcpListener,
make_service: M, make_service: M,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
@ -145,9 +147,35 @@ impl<M, S> Serve<M, S> {
tcp_listener: self.tcp_listener, tcp_listener: self.tcp_listener,
make_service: self.make_service, make_service: self.make_service,
signal, signal,
tcp_nodelay: self.tcp_nodelay,
_marker: PhantomData, _marker: PhantomData,
} }
} }
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
///
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
@ -159,12 +187,14 @@ where
let Self { let Self {
tcp_listener, tcp_listener,
make_service, make_service,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
f.debug_struct("Serve") f.debug_struct("Serve")
.field("tcp_listener", tcp_listener) .field("tcp_listener", tcp_listener)
.field("make_service", make_service) .field("make_service", make_service)
.field("tcp_nodelay", tcp_nodelay)
.finish() .finish()
} }
} }
@ -185,6 +215,7 @@ where
let Self { let Self {
tcp_listener, tcp_listener,
mut make_service, mut make_service,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
@ -193,6 +224,13 @@ where
Some(conn) => conn, Some(conn) => conn,
None => continue, None => continue,
}; };
if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let tcp_stream = TokioIo::new(tcp_stream); let tcp_stream = TokioIo::new(tcp_stream);
poll_fn(|cx| make_service.poll_ready(cx)) poll_fn(|cx| make_service.poll_ready(cx))
@ -239,9 +277,42 @@ pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener, tcp_listener: TcpListener,
make_service: M, make_service: M,
signal: F, signal: F,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<M, S, F> WithGracefulShutdown<M, S, F> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
///
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .with_graceful_shutdown(shutdown_signal())
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
///
/// async fn shutdown_signal() {
/// // ...
/// }
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F> impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where where
@ -254,6 +325,7 @@ where
tcp_listener, tcp_listener,
make_service, make_service,
signal, signal,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
@ -261,6 +333,7 @@ where
.field("tcp_listener", tcp_listener) .field("tcp_listener", tcp_listener)
.field("make_service", make_service) .field("make_service", make_service)
.field("signal", signal) .field("signal", signal)
.field("tcp_nodelay", tcp_nodelay)
.finish() .finish()
} }
} }
@ -282,6 +355,7 @@ where
tcp_listener, tcp_listener,
mut make_service, mut make_service,
signal, signal,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
@ -309,6 +383,13 @@ where
break; break;
} }
}; };
if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let tcp_stream = TokioIo::new(tcp_stream); let tcp_stream = TokioIo::new(tcp_stream);
trace!("connection {remote_addr} accepted"); trace!("connection {remote_addr} accepted");
@ -557,6 +638,20 @@ mod tests {
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(), handler.into_make_service_with_connect_info::<SocketAddr>(),
); );
// nodelay
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
)
.tcp_nodelay(true);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
)
.with_graceful_shutdown(async { /*...*/ })
.tcp_nodelay(true);
} }
async fn handler() {} async fn handler() {}