Mirror of https://github.com/tokio-rs/axum.git, synced 2024-11-21 14:46:32 +01:00

Support graceful shutdown on serve (#2398)

This commit is contained in:
parent 56159b0d4e
commit 12e8c6219d

8 changed files with 249 additions and 139 deletions

@@ -13,8 +13,8 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
           <axum::http::request::Parts as FromRequestParts<S>>
           <Uri as FromRequestParts<S>>
           <Version as FromRequestParts<S>>
           <ConnectInfo<T> as FromRequestParts<S>>
           <Extensions as FromRequestParts<S>>
           <ConnectInfo<T> as FromRequestParts<S>>
         and $N others
   = note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>`
 note: required by a bound in `__axum_macros_check_handler_0_from_request_check`

@@ -14,5 +14,5 @@ error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
           <axum::http::request::Parts as FromRequestParts<S>>
           <Uri as FromRequestParts<S>>
           <Version as FromRequestParts<S>>
           <ConnectInfo<T> as FromRequestParts<S>>
           <Extensions as FromRequestParts<S>>
         and $N others

@@ -10,10 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - **added:** `Body` implements `From<()>` now ([#2411])
 - **change:** Update version of multer used internally for multipart ([#2433])
 - **change:** Update tokio-tungstenite to 0.21 ([#2435])
+- **added:** Support graceful shutdown on `serve` ([#2398])
 
 [#2411]: https://github.com/tokio-rs/axum/pull/2411
 [#2433]: https://github.com/tokio-rs/axum/pull/2433
 [#2435]: https://github.com/tokio-rs/axum/pull/2435
+[#2398]: https://github.com/tokio-rs/axum/pull/2398
 
 # 0.7.2 (03. December, 2023)

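The headline change is the last `added` entry: `serve` now returns a value whose `with_graceful_shutdown` method accepts an arbitrary shutdown future. A minimal sketch of the new call shape (the ctrl-c trigger here is illustrative; any `Future<Output = ()> + Send + 'static` works):

```rust
use axum::{routing::get, Router};

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(|| async { "Hello, World!" }));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();

    // New in this commit: chain `with_graceful_shutdown` onto `serve`.
    axum::serve(listener, app)
        .with_graceful_shutdown(async {
            // Illustrative trigger: resolve when ctrl-c is received.
            tokio::signal::ctrl_c().await.unwrap();
        })
        .await
        .unwrap();
}
```
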
@@ -22,7 +22,7 @@ matched-path = []
 multipart = ["dep:multer"]
 original-uri = []
 query = ["dep:serde_urlencoded"]
-tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make"]
+tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"]
 tower-log = ["tower/log"]
 tracing = ["dep:tracing", "axum-core/tracing"]
 ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"]

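Note: the new `tokio/macros` entry is what enables the `tokio::select!` invocations introduced in `serve.rs` below; `select!` is gated behind tokio's `macros` feature.
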
@@ -53,8 +53,8 @@ tower-service = "0.3"
 # optional dependencies
 axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true }
 base64 = { version = "0.21.0", optional = true }
-hyper = { version = "1.0.0", optional = true }
-hyper-util = { version = "0.1.1", features = ["tokio", "server", "server-auto"], optional = true }
+hyper = { version = "1.1.0", optional = true }
+hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"], optional = true }
 multer = { version = "3.0.0", optional = true }
 serde_json = { version = "1.0", features = ["raw_value"], optional = true }
 serde_path_to_error = { version = "0.1.8", optional = true }

@@ -67,6 +67,13 @@ macro_rules! all_the_tuples {
     };
 }
 
+#[cfg(feature = "tracing")]
+macro_rules! trace {
+    ($($tt:tt)*) => {
+        tracing::trace!($($tt)*)
+    }
+}
+
 #[cfg(feature = "tracing")]
 macro_rules! error {
     ($($tt:tt)*) => {

@@ -74,6 +81,11 @@ macro_rules! error {
     };
 }
 
+#[cfg(not(feature = "tracing"))]
+macro_rules! trace {
+    ($($tt:tt)*) => {};
+}
+
 #[cfg(not(feature = "tracing"))]
 macro_rules! error {
     ($($tt:tt)*) => {};

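These two hunks add a `trace!` shim next to the existing `error!` shim: with the `tracing` feature enabled the macro forwards to `tracing::trace!`, and without it the macro expands to nothing, so call sites compile either way. A self-contained sketch of the same pattern (standalone, outside axum; with the feature enabled it assumes a `tracing` dependency):

```rust
// With the feature on, forward to the real logger.
#[cfg(feature = "tracing")]
macro_rules! trace {
    ($($tt:tt)*) => { tracing::trace!($($tt)*) }
}

// Without the feature, expand to nothing, so the call costs nothing.
#[cfg(not(feature = "tracing"))]
macro_rules! trace {
    ($($tt:tt)*) => {};
}

fn main() {
    trace!("starting up"); // a no-op unless the `tracing` feature is enabled
}
```
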
@@ -2,24 +2,29 @@
 
 use std::{
     convert::Infallible,
-    future::{Future, IntoFuture},
+    fmt::Debug,
+    future::{poll_fn, Future, IntoFuture},
     io,
     marker::PhantomData,
     net::SocketAddr,
     pin::Pin,
+    sync::Arc,
     task::{Context, Poll},
     time::Duration,
 };
 
 use axum_core::{body::Body, extract::Request, response::Response};
-use futures_util::future::poll_fn;
+use futures_util::{pin_mut, FutureExt};
 use hyper::body::Incoming;
 use hyper_util::{
     rt::{TokioExecutor, TokioIo},
     server::conn::auto::Builder,
 };
 use pin_project_lite::pin_project;
-use tokio::net::{TcpListener, TcpStream};
+use tokio::{
+    net::{TcpListener, TcpStream},
+    sync::watch,
+};
 use tower::util::{Oneshot, ServiceExt};
 use tower_service::Service;
 

@@ -110,9 +115,45 @@ pub struct Serve<M, S> {
 }
 
 #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
-impl<M, S> std::fmt::Debug for Serve<M, S>
+impl<M, S> Serve<M, S> {
+    /// Prepares a server to handle graceful shutdown when the provided future completes.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use axum::{Router, routing::get};
+    ///
+    /// # async {
+    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
+    ///
+    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
+    /// axum::serve(listener, router)
+    ///     .with_graceful_shutdown(shutdown_signal())
+    ///     .await
+    ///     .unwrap();
+    /// # };
+    ///
+    /// async fn shutdown_signal() {
+    ///     // ...
+    /// }
+    /// ```
+    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
+    where
+        F: Future<Output = ()> + Send + 'static,
+    {
+        WithGracefulShutdown {
+            tcp_listener: self.tcp_listener,
+            make_service: self.make_service,
+            signal,
+            _marker: PhantomData,
+        }
+    }
+}
+
+#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
+impl<M, S> Debug for Serve<M, S>
 where
-    M: std::fmt::Debug,
+    M: Debug,
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let Self {

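Because `with_graceful_shutdown` accepts any `Future<Output = ()> + Send + 'static`, the trigger does not have to be a signal handler. A hedged sketch of driving shutdown programmatically, for example from a test harness, using a oneshot channel (the names here are illustrative, not part of the diff):

```rust
use axum::{routing::get, Router};
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(|| async { "ok" }));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();

    // Illustrative: any future resolving to `()` works as the signal.
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

    let server = tokio::spawn(async move {
        axum::serve(listener, app)
            .with_graceful_shutdown(async {
                // Ignore the error case where the sender was dropped.
                let _ = shutdown_rx.await;
            })
            .await
            .unwrap();
    });

    // ... exercise the server, then ask it to stop:
    let _ = shutdown_tx.send(());
    server.await.unwrap();
}
```
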
@@ -148,30 +189,9 @@ where
         } = self;
 
         loop {
-            let (tcp_stream, remote_addr) = match tcp_listener.accept().await {
-                Ok(conn) => conn,
-                Err(e) => {
-                    // Connection errors can be ignored directly, continue
-                    // by accepting the next request.
-                    if is_connection_error(&e) {
-                        continue;
-                    }
-
-                    // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
-                    //
-                    // > A possible scenario is that the process has hit the max open files
-                    // > allowed, and so trying to accept a new connection will fail with
-                    // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
-                    // > the application will likely close some files (or connections), and try
-                    // > to accept the connection again. If this option is `true`, the error
-                    // > will be logged at the `error` level, since it is still a big deal,
-                    // > and then the listener will sleep for 1 second.
-                    //
-                    // hyper allowed customizing this but axum does not.
-                    error!("accept error: {e}");
-                    tokio::time::sleep(Duration::from_secs(1)).await;
-                    continue;
-                }
+            let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
+                Some(conn) => conn,
+                None => continue,
             };
             let tcp_stream = TokioIo::new(tcp_stream);
 

@@ -191,7 +211,7 @@ where
                 service: tower_service,
             };
 
-            tokio::task::spawn(async move {
+            tokio::spawn(async move {
                 match Builder::new(TokioExecutor::new())
                     // upgrades needed for websockets
                     .serve_connection_with_upgrades(tcp_stream, hyper_service)

@@ -212,6 +232,149 @@ where
         }
     }
 }
+
+/// Serve future with graceful shutdown enabled.
+#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
+pub struct WithGracefulShutdown<M, S, F> {
+    tcp_listener: TcpListener,
+    make_service: M,
+    signal: F,
+    _marker: PhantomData<S>,
+}
+
+#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
+impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
+where
+    M: Debug,
+    S: Debug,
+    F: Debug,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let Self {
+            tcp_listener,
+            make_service,
+            signal,
+            _marker: _,
+        } = self;
+
+        f.debug_struct("WithGracefulShutdown")
+            .field("tcp_listener", tcp_listener)
+            .field("make_service", make_service)
+            .field("signal", signal)
+            .finish()
+    }
+}
+
+#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
+impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
+where
+    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
+    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
+    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
+    S::Future: Send,
+    F: Future<Output = ()> + Send + 'static,
+{
+    type Output = io::Result<()>;
+    type IntoFuture = private::ServeFuture;
+
+    fn into_future(self) -> Self::IntoFuture {
+        let Self {
+            tcp_listener,
+            mut make_service,
+            signal,
+            _marker: _,
+        } = self;
+
+        let (signal_tx, signal_rx) = watch::channel(());
+        let signal_tx = Arc::new(signal_tx);
+        tokio::spawn(async move {
+            signal.await;
+            trace!("received graceful shutdown signal. Telling tasks to shutdown");
+            drop(signal_rx);
+        });
+
+        let (close_tx, close_rx) = watch::channel(());
+
+        private::ServeFuture(Box::pin(async move {
+            loop {
+                let (tcp_stream, remote_addr) = tokio::select! {
+                    conn = tcp_accept(&tcp_listener) => {
+                        match conn {
+                            Some(conn) => conn,
+                            None => continue,
+                        }
+                    }
+                    _ = signal_tx.closed() => {
+                        trace!("signal received, not accepting new connections");
+                        break;
+                    }
+                };
+                let tcp_stream = TokioIo::new(tcp_stream);
+
+                trace!("connection {remote_addr} accepted");
+
+                poll_fn(|cx| make_service.poll_ready(cx))
+                    .await
+                    .unwrap_or_else(|err| match err {});
+
+                let tower_service = make_service
+                    .call(IncomingStream {
+                        tcp_stream: &tcp_stream,
+                        remote_addr,
+                    })
+                    .await
+                    .unwrap_or_else(|err| match err {});
+
+                let hyper_service = TowerToHyperService {
+                    service: tower_service,
+                };
+
+                let signal_tx = Arc::clone(&signal_tx);
+
+                let close_rx = close_rx.clone();
+
+                tokio::spawn(async move {
+                    let builder = Builder::new(TokioExecutor::new());
+                    let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
+                    pin_mut!(conn);
+
+                    let signal_closed = signal_tx.closed().fuse();
+                    pin_mut!(signal_closed);
+
+                    loop {
+                        tokio::select! {
+                            result = conn.as_mut() => {
+                                if let Err(_err) = result {
+                                    trace!("failed to serve connection: {_err:#}");
+                                }
+                                break;
+                            }
+                            _ = &mut signal_closed => {
+                                trace!("signal received in task, starting graceful shutdown");
+                                conn.as_mut().graceful_shutdown();
+                            }
+                        }
+                    }
+
+                    trace!("connection {remote_addr} closed");
+
+                    drop(close_rx);
+                });
+            }
+
+            drop(close_rx);
+            drop(tcp_listener);
+
+            trace!(
+                "waiting for {} task(s) to finish",
+                close_tx.receiver_count()
+            );
+            close_tx.closed().await;
+
+            Ok(())
+        }))
+    }
+}
 
 fn is_connection_error(e: &io::Error) -> bool {
     matches!(
         e.kind(),

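The shutdown machinery above leans on two `tokio::sync::watch` channels used purely for their close-notification behavior: `signal_tx.closed()` resolves once the shutdown future drops `signal_rx`, and `close_tx.closed()` resolves once every per-connection task drops its `close_rx` clone. A self-contained sketch of that counting pattern (illustrative names, independent of axum):

```rust
use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One sender kept by the coordinator; one receiver clone per task.
    let (close_tx, close_rx) = watch::channel(());

    for i in 0..3 {
        let close_rx = close_rx.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50 * i)).await;
            println!("task {i} done");
            // Dropping the receiver is the "I'm finished" signal.
            drop(close_rx);
        });
    }

    // Drop the coordinator's own receiver so only the task clones remain.
    drop(close_rx);

    println!("waiting for {} task(s)", close_tx.receiver_count());
    // Resolves once every receiver clone has been dropped.
    close_tx.closed().await;
    println!("all tasks finished");
}
```
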
@@ -221,6 +384,32 @@ fn is_connection_error(e: &io::Error) -> bool {
     )
 }
 
+async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
+    match listener.accept().await {
+        Ok(conn) => Some(conn),
+        Err(e) => {
+            if is_connection_error(&e) {
+                return None;
+            }
+
+            // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
+            //
+            // > A possible scenario is that the process has hit the max open files
+            // > allowed, and so trying to accept a new connection will fail with
+            // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
+            // > the application will likely close some files (or connections), and try
+            // > to accept the connection again. If this option is `true`, the error
+            // > will be logged at the `error` level, since it is still a big deal,
+            // > and then the listener will sleep for 1 second.
+            //
+            // hyper allowed customizing this but axum does not.
+            error!("accept error: {e}");
+            tokio::time::sleep(Duration::from_secs(1)).await;
+            None
+        }
+    }
+}
+
 mod private {
     use std::{
         future::Future,

@@ -5,7 +5,7 @@ edition = "2021"
 publish = false
 
 [dependencies]
-axum = { path = "../../axum" }
+axum = { path = "../../axum", features = ["tracing"] }
 hyper = { version = "1.0", features = [] }
 hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
 tokio = { version = "1.0", features = ["full"] }

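The example now turns on axum's `tracing` feature, so the `trace!` events added to `serve.rs` in this commit are actually emitted; correspondingly, the example's default log filter (two hunks below) is widened with `axum=trace` to show them.
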
@@ -10,17 +10,12 @@
 
 use std::time::Duration;
 
-use axum::{extract::Request, routing::get, Router};
-use hyper::body::Incoming;
-use hyper_util::rt::TokioIo;
+use axum::{routing::get, Router};
 use tokio::net::TcpListener;
 use tokio::signal;
-use tokio::sync::watch;
 use tokio::time::sleep;
-use tower::Service;
 use tower_http::timeout::TimeoutLayer;
 use tower_http::trace::TraceLayer;
-use tracing::debug;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 #[tokio::main]

@@ -28,10 +23,11 @@ async fn main() {
     // Enable tracing.
     tracing_subscriber::registry()
         .with(
-            tracing_subscriber::EnvFilter::try_from_default_env()
-                .unwrap_or_else(|_| "example_graceful_shutdown=debug,tower_http=debug".into()),
+            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
+                "example_graceful_shutdown=debug,tower_http=debug,axum=trace".into()
+            }),
         )
-        .with(tracing_subscriber::fmt::layer())
+        .with(tracing_subscriber::fmt::layer().without_time())
         .init();
 
     // Create a regular axum app.

@@ -48,100 +44,11 @@ async fn main() {
     // Create a `TcpListener` using tokio.
     let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
 
-    // Create a watch channel to track tasks that are handling connections and wait for them to
-    // complete.
-    let (close_tx, close_rx) = watch::channel(());
-
-    // Continuously accept new connections.
-    loop {
-        let (socket, remote_addr) = tokio::select! {
-            // Either accept a new connection...
-            result = listener.accept() => {
-                result.unwrap()
-            }
-            // ...or wait to receive a shutdown signal and stop the accept loop.
-            _ = shutdown_signal() => {
-                debug!("signal received, not accepting new connections");
-                break;
-            }
-        };
-
-        debug!("connection {remote_addr} accepted");
-
-        // We don't need to call `poll_ready` because `Router` is always ready.
-        let tower_service = app.clone();
-
-        // Clone the watch receiver and move it into the task.
-        let close_rx = close_rx.clone();
-
-        // Spawn a task to handle the connection. That way we can serve multiple connections
-        // concurrently.
-        tokio::spawn(async move {
-            // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
-            // `TokioIo` converts between them.
-            let socket = TokioIo::new(socket);
-
-            // Hyper also has its own `Service` trait and doesn't use tower. We can use
-            // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
-            // `tower::Service::call`.
-            let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
-                // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
-                // tower's `Service` requires `&mut self`.
-                //
-                // We don't need to call `poll_ready` since `Router` is always ready.
-                tower_service.clone().call(request)
-            });
-
-            // `hyper_util::server::conn::auto::Builder` supports both http1 and http2 but doesn't
-            // support graceful so we have to use hyper directly and unfortunately pick between
-            // http1 and http2.
-            let conn = hyper::server::conn::http1::Builder::new()
-                .serve_connection(socket, hyper_service)
-                // `with_upgrades` is required for websockets.
-                .with_upgrades();
-
-            // `graceful_shutdown` requires a pinned connection.
-            let mut conn = std::pin::pin!(conn);
-
-            loop {
-                tokio::select! {
-                    // Poll the connection. This completes when the client has closed the
-                    // connection, graceful shutdown has completed, or we encounter a TCP error.
-                    result = conn.as_mut() => {
-                        if let Err(err) = result {
-                            debug!("failed to serve connection: {err:#}");
-                        }
-                        break;
-                    }
-                    // Start graceful shutdown when we receive a shutdown signal.
-                    //
-                    // We use a loop to continue polling the connection to allow requests to finish
-                    // after starting graceful shutdown. Our `Router` has `TimeoutLayer` so
-                    // requests will finish after at most 10 seconds.
-                    _ = shutdown_signal() => {
-                        debug!("signal received, starting graceful shutdown");
-                        conn.as_mut().graceful_shutdown();
-                    }
-                }
-            }
-
-            debug!("connection {remote_addr} closed");
-
-            // Drop the watch receiver to signal to `main` that this task is done.
-            drop(close_rx);
-        });
-    }
-
-    // We only care about the watch receivers that were moved into the tasks so close the residual
-    // receiver.
-    drop(close_rx);
-
-    // Close the listener to stop accepting new connections.
-    drop(listener);
-
-    // Wait for all tasks to complete.
-    debug!("waiting for {} tasks to finish", close_tx.receiver_count());
-    close_tx.closed().await;
+    // Run the server with graceful shutdown
+    axum::serve(listener, app)
+        .with_graceful_shutdown(shutdown_signal())
+        .await
+        .unwrap();
 }
 
 async fn shutdown_signal() {

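The page cuts off inside `shutdown_signal`. For reference, a hedged sketch of a typical implementation of such a helper (not necessarily the exact body used in the example): a future that resolves on either Ctrl+C or, on Unix, SIGTERM.

```rust
use tokio::signal;

async fn shutdown_signal() {
    // Resolve when ctrl-c is received.
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // On Unix, also resolve on SIGTERM (what most process managers send).
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
}
```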