mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-04 22:14:44 +01:00
Make serve
generic over the listener and IO types (#2941)
Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
parent
d84136e1e4
commit
84c3960639
9 changed files with 399 additions and 254 deletions
|
@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
|
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
|
||||||
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
|
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
|
||||||
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
|
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
|
||||||
|
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
|
||||||
|
|
||||||
[#2897]: https://github.com/tokio-rs/axum/pull/2897
|
[#2897]: https://github.com/tokio-rs/axum/pull/2897
|
||||||
[#2903]: https://github.com/tokio-rs/axum/pull/2903
|
[#2903]: https://github.com/tokio-rs/axum/pull/2903
|
||||||
|
@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
[#2992]: https://github.com/tokio-rs/axum/pull/2992
|
[#2992]: https://github.com/tokio-rs/axum/pull/2992
|
||||||
[#2720]: https://github.com/tokio-rs/axum/pull/2720
|
[#2720]: https://github.com/tokio-rs/axum/pull/2720
|
||||||
[#3039]: https://github.com/tokio-rs/axum/pull/3039
|
[#3039]: https://github.com/tokio-rs/axum/pull/3039
|
||||||
|
[#2941]: https://github.com/tokio-rs/axum/pull/2941
|
||||||
|
|
||||||
# 0.8.0
|
# 0.8.0
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ use axum::{
|
||||||
serve::IncomingStream,
|
serve::IncomingStream,
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
let app = Router::new().route("/", get(handler));
|
let app = Router::new().route("/", get(handler));
|
||||||
|
|
||||||
|
@ -49,8 +50,8 @@ struct MyConnectInfo {
|
||||||
// ...
|
// ...
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Connected<IncomingStream<'_>> for MyConnectInfo {
|
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
|
||||||
fn connect_info(target: IncomingStream<'_>) -> Self {
|
fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
|
||||||
MyConnectInfo {
|
MyConnectInfo {
|
||||||
// ...
|
// ...
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,16 +79,17 @@ where
|
||||||
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
|
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
|
||||||
pub trait Connected<T>: Clone + Send + Sync + 'static {
|
pub trait Connected<T>: Clone + Send + Sync + 'static {
|
||||||
/// Create type holding information about the connection.
|
/// Create type holding information about the connection.
|
||||||
fn connect_info(target: T) -> Self;
|
fn connect_info(stream: T) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
const _: () = {
|
const _: () = {
|
||||||
use crate::serve::IncomingStream;
|
use crate::serve;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
impl Connected<IncomingStream<'_>> for SocketAddr {
|
impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
|
||||||
fn connect_info(target: IncomingStream<'_>) -> Self {
|
fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
|
||||||
target.remote_addr()
|
*stream.remote_addr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -261,8 +262,8 @@ mod tests {
|
||||||
value: &'static str,
|
value: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Connected<IncomingStream<'_>> for MyConnectInfo {
|
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
|
||||||
fn connect_info(_target: IncomingStream<'_>) -> Self {
|
fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
value: "it worked!",
|
value: "it worked!",
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,12 +180,13 @@ where
|
||||||
// for `axum::serve(listener, handler)`
|
// for `axum::serve(listener, handler)`
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
const _: () = {
|
const _: () = {
|
||||||
use crate::serve::IncomingStream;
|
use crate::serve;
|
||||||
|
|
||||||
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
|
impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
|
||||||
where
|
where
|
||||||
H: Clone,
|
H: Clone,
|
||||||
S: Clone,
|
S: Clone,
|
||||||
|
L: serve::Listener,
|
||||||
{
|
{
|
||||||
type Response = Self;
|
type Response = Self;
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
|
@ -195,7 +196,7 @@ const _: () = {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||||
std::future::ready(Ok(self.clone()))
|
std::future::ready(Ok(self.clone()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1313,9 +1313,12 @@ where
|
||||||
// for `axum::serve(listener, router)`
|
// for `axum::serve(listener, router)`
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
const _: () = {
|
const _: () = {
|
||||||
use crate::serve::IncomingStream;
|
use crate::serve;
|
||||||
|
|
||||||
impl Service<IncomingStream<'_>> for MethodRouter<()> {
|
impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
|
||||||
|
where
|
||||||
|
L: serve::Listener,
|
||||||
|
{
|
||||||
type Response = Self;
|
type Response = Self;
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
||||||
|
@ -1324,7 +1327,7 @@ const _: () = {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||||
std::future::ready(Ok(self.clone().with_state(())))
|
std::future::ready(Ok(self.clone().with_state(())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -518,9 +518,12 @@ impl Router {
|
||||||
// for `axum::serve(listener, router)`
|
// for `axum::serve(listener, router)`
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
const _: () = {
|
const _: () = {
|
||||||
use crate::serve::IncomingStream;
|
use crate::serve;
|
||||||
|
|
||||||
impl Service<IncomingStream<'_>> for Router<()> {
|
impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
|
||||||
|
where
|
||||||
|
L: serve::Listener,
|
||||||
|
{
|
||||||
type Response = Self;
|
type Response = Self;
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
||||||
|
@ -529,7 +532,7 @@ const _: () = {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||||
// call `Router::with_state` such that everything is turned into `Route` eagerly
|
// call `Router::with_state` such that everything is turned into `Route` eagerly
|
||||||
// rather than doing that per request
|
// rather than doing that per request
|
||||||
std::future::ready(Ok(self.clone().with_state(())))
|
std::future::ready(Ok(self.clone().with_state(())))
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
//! Serve services.
|
//! Serve services.
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
|
any::TypeId,
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
fmt::Debug,
|
fmt::Debug,
|
||||||
future::{poll_fn, Future, IntoFuture},
|
future::{poll_fn, Future, IntoFuture},
|
||||||
io,
|
io,
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
net::SocketAddr,
|
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
time::Duration,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum_core::{body::Body, extract::Request, response::Response};
|
use axum_core::{body::Body, extract::Request, response::Response};
|
||||||
|
@ -17,13 +16,14 @@ use hyper::body::Incoming;
|
||||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||||
#[cfg(any(feature = "http1", feature = "http2"))]
|
#[cfg(any(feature = "http1", feature = "http2"))]
|
||||||
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
|
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
|
||||||
use tokio::{
|
use tokio::{net::TcpListener, sync::watch};
|
||||||
net::{TcpListener, TcpStream},
|
|
||||||
sync::watch,
|
|
||||||
};
|
|
||||||
use tower::ServiceExt as _;
|
use tower::ServiceExt as _;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
|
mod listener;
|
||||||
|
|
||||||
|
pub use self::listener::{Listener, ListenerExt, TapIo};
|
||||||
|
|
||||||
/// Serve the service with the supplied listener.
|
/// Serve the service with the supplied listener.
|
||||||
///
|
///
|
||||||
/// This method of running a service is intentionally simple and doesn't support any configuration.
|
/// This method of running a service is intentionally simple and doesn't support any configuration.
|
||||||
|
@ -89,14 +89,15 @@ use tower_service::Service;
|
||||||
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
|
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
|
||||||
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
|
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
|
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
|
||||||
where
|
where
|
||||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
|
L: Listener,
|
||||||
|
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
|
||||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
{
|
{
|
||||||
Serve {
|
Serve {
|
||||||
tcp_listener,
|
listener,
|
||||||
make_service,
|
make_service,
|
||||||
tcp_nodelay: None,
|
tcp_nodelay: None,
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
|
@ -106,15 +107,18 @@ where
|
||||||
/// Future returned by [`serve`].
|
/// Future returned by [`serve`].
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
#[must_use = "futures must be awaited or polled"]
|
#[must_use = "futures must be awaited or polled"]
|
||||||
pub struct Serve<M, S> {
|
pub struct Serve<L, M, S> {
|
||||||
tcp_listener: TcpListener,
|
listener: L,
|
||||||
make_service: M,
|
make_service: M,
|
||||||
tcp_nodelay: Option<bool>,
|
tcp_nodelay: Option<bool>,
|
||||||
_marker: PhantomData<S>,
|
_marker: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
impl<M, S> Serve<M, S> {
|
impl<L, M, S> Serve<L, M, S>
|
||||||
|
where
|
||||||
|
L: Listener,
|
||||||
|
{
|
||||||
/// Prepares a server to handle graceful shutdown when the provided future completes.
|
/// Prepares a server to handle graceful shutdown when the provided future completes.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
@ -136,76 +140,57 @@ impl<M, S> Serve<M, S> {
|
||||||
/// // ...
|
/// // ...
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
|
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
|
||||||
where
|
where
|
||||||
F: Future<Output = ()> + Send + 'static,
|
F: Future<Output = ()> + Send + 'static,
|
||||||
{
|
{
|
||||||
WithGracefulShutdown {
|
WithGracefulShutdown {
|
||||||
tcp_listener: self.tcp_listener,
|
listener: self.listener,
|
||||||
make_service: self.make_service,
|
make_service: self.make_service,
|
||||||
signal,
|
signal,
|
||||||
tcp_nodelay: self.tcp_nodelay,
|
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
|
|
||||||
///
|
|
||||||
/// See also [`TcpStream::set_nodelay`].
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
/// ```
|
|
||||||
/// use axum::{Router, routing::get};
|
|
||||||
///
|
|
||||||
/// # async {
|
|
||||||
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
|
||||||
///
|
|
||||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
|
||||||
/// axum::serve(listener, router)
|
|
||||||
/// .tcp_nodelay(true)
|
|
||||||
/// .await
|
|
||||||
/// .unwrap();
|
|
||||||
/// # };
|
|
||||||
/// ```
|
|
||||||
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
|
|
||||||
Self {
|
|
||||||
tcp_nodelay: Some(nodelay),
|
|
||||||
..self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the local address this server is bound to.
|
/// Returns the local address this server is bound to.
|
||||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
pub fn local_addr(&self) -> io::Result<L::Addr> {
|
||||||
self.tcp_listener.local_addr()
|
self.listener.local_addr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
impl<M, S> Debug for Serve<M, S>
|
impl<L, M, S> Debug for Serve<L, M, S>
|
||||||
where
|
where
|
||||||
|
L: Debug + 'static,
|
||||||
M: Debug,
|
M: Debug,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let Self {
|
let Self {
|
||||||
tcp_listener,
|
listener,
|
||||||
make_service,
|
make_service,
|
||||||
tcp_nodelay,
|
tcp_nodelay,
|
||||||
_marker: _,
|
_marker: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
f.debug_struct("Serve")
|
let mut s = f.debug_struct("Serve");
|
||||||
.field("tcp_listener", tcp_listener)
|
s.field("listener", listener)
|
||||||
.field("make_service", make_service)
|
.field("make_service", make_service);
|
||||||
.field("tcp_nodelay", tcp_nodelay)
|
|
||||||
.finish()
|
if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
|
||||||
|
s.field("tcp_nodelay", tcp_nodelay);
|
||||||
|
}
|
||||||
|
|
||||||
|
s.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
impl<M, S> IntoFuture for Serve<M, S>
|
impl<L, M, S> IntoFuture for Serve<L, M, S>
|
||||||
where
|
where
|
||||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
|
L: Listener,
|
||||||
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
|
L::Addr: Debug,
|
||||||
|
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
|
||||||
|
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
|
||||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
{
|
{
|
||||||
|
@ -221,81 +206,55 @@ where
|
||||||
/// Serve future with graceful shutdown enabled.
|
/// Serve future with graceful shutdown enabled.
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
#[must_use = "futures must be awaited or polled"]
|
#[must_use = "futures must be awaited or polled"]
|
||||||
pub struct WithGracefulShutdown<M, S, F> {
|
pub struct WithGracefulShutdown<L, M, S, F> {
|
||||||
tcp_listener: TcpListener,
|
listener: L,
|
||||||
make_service: M,
|
make_service: M,
|
||||||
signal: F,
|
signal: F,
|
||||||
tcp_nodelay: Option<bool>,
|
|
||||||
_marker: PhantomData<S>,
|
_marker: PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M, S, F> WithGracefulShutdown<M, S, F> {
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
|
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
|
||||||
///
|
where
|
||||||
/// See also [`TcpStream::set_nodelay`].
|
L: Listener,
|
||||||
///
|
{
|
||||||
/// # Example
|
|
||||||
/// ```
|
|
||||||
/// use axum::{Router, routing::get};
|
|
||||||
///
|
|
||||||
/// # async {
|
|
||||||
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
|
||||||
///
|
|
||||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
|
||||||
/// axum::serve(listener, router)
|
|
||||||
/// .with_graceful_shutdown(shutdown_signal())
|
|
||||||
/// .tcp_nodelay(true)
|
|
||||||
/// .await
|
|
||||||
/// .unwrap();
|
|
||||||
/// # };
|
|
||||||
///
|
|
||||||
/// async fn shutdown_signal() {
|
|
||||||
/// // ...
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
|
|
||||||
Self {
|
|
||||||
tcp_nodelay: Some(nodelay),
|
|
||||||
..self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the local address this server is bound to.
|
/// Returns the local address this server is bound to.
|
||||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
pub fn local_addr(&self) -> io::Result<L::Addr> {
|
||||||
self.tcp_listener.local_addr()
|
self.listener.local_addr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
|
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
|
||||||
where
|
where
|
||||||
|
L: Debug + 'static,
|
||||||
M: Debug,
|
M: Debug,
|
||||||
S: Debug,
|
S: Debug,
|
||||||
F: Debug,
|
F: Debug,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let Self {
|
let Self {
|
||||||
tcp_listener,
|
listener,
|
||||||
make_service,
|
make_service,
|
||||||
signal,
|
signal,
|
||||||
tcp_nodelay,
|
|
||||||
_marker: _,
|
_marker: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
f.debug_struct("WithGracefulShutdown")
|
f.debug_struct("WithGracefulShutdown")
|
||||||
.field("tcp_listener", tcp_listener)
|
.field("listener", listener)
|
||||||
.field("make_service", make_service)
|
.field("make_service", make_service)
|
||||||
.field("signal", signal)
|
.field("signal", signal)
|
||||||
.field("tcp_nodelay", tcp_nodelay)
|
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||||
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
|
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
|
||||||
where
|
where
|
||||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
|
L: Listener,
|
||||||
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
|
L::Addr: Debug,
|
||||||
|
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
|
||||||
|
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
|
||||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
F: Future<Output = ()> + Send + 'static,
|
F: Future<Output = ()> + Send + 'static,
|
||||||
|
@ -305,10 +264,9 @@ where
|
||||||
|
|
||||||
fn into_future(self) -> Self::IntoFuture {
|
fn into_future(self) -> Self::IntoFuture {
|
||||||
let Self {
|
let Self {
|
||||||
tcp_listener,
|
mut listener,
|
||||||
mut make_service,
|
mut make_service,
|
||||||
signal,
|
signal,
|
||||||
tcp_nodelay,
|
|
||||||
_marker: _,
|
_marker: _,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
|
@ -324,28 +282,17 @@ where
|
||||||
let (close_tx, close_rx) = watch::channel(());
|
let (close_tx, close_rx) = watch::channel(());
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (tcp_stream, remote_addr) = tokio::select! {
|
let (io, remote_addr) = tokio::select! {
|
||||||
conn = tcp_accept(&tcp_listener) => {
|
conn = listener.accept() => conn,
|
||||||
match conn {
|
|
||||||
Some(conn) => conn,
|
|
||||||
None => continue,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = signal_tx.closed() => {
|
_ = signal_tx.closed() => {
|
||||||
trace!("signal received, not accepting new connections");
|
trace!("signal received, not accepting new connections");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(nodelay) = tcp_nodelay {
|
let io = TokioIo::new(io);
|
||||||
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
|
|
||||||
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let tcp_stream = TokioIo::new(tcp_stream);
|
trace!("connection {remote_addr:?} accepted");
|
||||||
|
|
||||||
trace!("connection {remote_addr} accepted");
|
|
||||||
|
|
||||||
poll_fn(|cx| make_service.poll_ready(cx))
|
poll_fn(|cx| make_service.poll_ready(cx))
|
||||||
.await
|
.await
|
||||||
|
@ -353,7 +300,7 @@ where
|
||||||
|
|
||||||
let tower_service = make_service
|
let tower_service = make_service
|
||||||
.call(IncomingStream {
|
.call(IncomingStream {
|
||||||
tcp_stream: &tcp_stream,
|
io: &io,
|
||||||
remote_addr,
|
remote_addr,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
@ -372,7 +319,7 @@ where
|
||||||
// CONNECT protocol needed for HTTP/2 websockets
|
// CONNECT protocol needed for HTTP/2 websockets
|
||||||
#[cfg(feature = "http2")]
|
#[cfg(feature = "http2")]
|
||||||
builder.http2().enable_connect_protocol();
|
builder.http2().enable_connect_protocol();
|
||||||
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
|
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
|
||||||
pin_mut!(conn);
|
pin_mut!(conn);
|
||||||
|
|
||||||
let signal_closed = signal_tx.closed().fuse();
|
let signal_closed = signal_tx.closed().fuse();
|
||||||
|
@ -393,14 +340,12 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trace!("connection {remote_addr} closed");
|
|
||||||
|
|
||||||
drop(close_rx);
|
drop(close_rx);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
drop(close_rx);
|
drop(close_rx);
|
||||||
drop(tcp_listener);
|
drop(listener);
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
"waiting for {} task(s) to finish",
|
"waiting for {} task(s) to finish",
|
||||||
|
@ -413,38 +358,32 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_connection_error(e: &io::Error) -> bool {
|
/// An incoming stream.
|
||||||
matches!(
|
///
|
||||||
e.kind(),
|
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
|
||||||
io::ErrorKind::ConnectionRefused
|
///
|
||||||
| io::ErrorKind::ConnectionAborted
|
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
|
||||||
| io::ErrorKind::ConnectionReset
|
#[derive(Debug)]
|
||||||
)
|
pub struct IncomingStream<'a, L>
|
||||||
|
where
|
||||||
|
L: Listener,
|
||||||
|
{
|
||||||
|
io: &'a TokioIo<L::Io>,
|
||||||
|
remote_addr: L::Addr,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
|
impl<L> IncomingStream<'_, L>
|
||||||
match listener.accept().await {
|
where
|
||||||
Ok(conn) => Some(conn),
|
L: Listener,
|
||||||
Err(e) => {
|
{
|
||||||
if is_connection_error(&e) {
|
/// Get a reference to the inner IO type.
|
||||||
return None;
|
pub fn io(&self) -> &L::Io {
|
||||||
|
self.io.inner()
|
||||||
}
|
}
|
||||||
|
|
||||||
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
|
/// Returns the remote address that this stream is bound to.
|
||||||
//
|
pub fn remote_addr(&self) -> &L::Addr {
|
||||||
// > A possible scenario is that the process has hit the max open files
|
&self.remote_addr
|
||||||
// > allowed, and so trying to accept a new connection will fail with
|
|
||||||
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
|
||||||
// > the application will likely close some files (or connections), and try
|
|
||||||
// > to accept the connection again. If this option is `true`, the error
|
|
||||||
// > will be logged at the `error` level, since it is still a big deal,
|
|
||||||
// > and then the listener will sleep for 1 second.
|
|
||||||
//
|
|
||||||
// hyper allowed customizing this but axum does not.
|
|
||||||
error!("accept error: {e}");
|
|
||||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -474,68 +413,89 @@ mod private {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An incoming stream.
|
|
||||||
///
|
|
||||||
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
|
|
||||||
///
|
|
||||||
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct IncomingStream<'a> {
|
|
||||||
tcp_stream: &'a TokioIo<TcpStream>,
|
|
||||||
remote_addr: SocketAddr,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IncomingStream<'_> {
|
|
||||||
/// Returns the local address that this stream is bound to.
|
|
||||||
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
|
|
||||||
self.tcp_stream.inner().local_addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the remote address that this stream is bound to.
|
|
||||||
pub fn remote_addr(&self) -> SocketAddr {
|
|
||||||
self.remote_addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use std::{
|
||||||
|
future::{pending, IntoFuture as _},
|
||||||
|
net::{IpAddr, Ipv4Addr},
|
||||||
|
};
|
||||||
|
|
||||||
|
use axum_core::{body::Body, extract::Request};
|
||||||
|
use http::StatusCode;
|
||||||
|
use hyper_util::rt::TokioIo;
|
||||||
|
use tokio::{
|
||||||
|
io::{self, AsyncRead, AsyncWrite},
|
||||||
|
net::{TcpListener, UnixListener},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{serve, IncomingStream, Listener};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
body::to_bytes,
|
||||||
|
extract::connect_info::Connected,
|
||||||
handler::{Handler, HandlerWithoutStateExt},
|
handler::{Handler, HandlerWithoutStateExt},
|
||||||
routing::get,
|
routing::get,
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use std::{
|
|
||||||
future::pending,
|
|
||||||
net::{IpAddr, Ipv4Addr},
|
|
||||||
};
|
|
||||||
|
|
||||||
#[allow(dead_code, unused_must_use)]
|
#[allow(dead_code, unused_must_use)]
|
||||||
async fn if_it_compiles_it_works() {
|
async fn if_it_compiles_it_works() {
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct UdsConnectInfo;
|
||||||
|
|
||||||
|
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
|
||||||
|
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let router: Router = Router::new();
|
let router: Router = Router::new();
|
||||||
|
|
||||||
let addr = "0.0.0.0:0";
|
let addr = "0.0.0.0:0";
|
||||||
|
|
||||||
// router
|
// router
|
||||||
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
|
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
|
||||||
|
serve(UnixListener::bind("").unwrap(), router.clone());
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
router.clone().into_make_service(),
|
router.clone().into_make_service(),
|
||||||
);
|
);
|
||||||
|
serve(
|
||||||
|
UnixListener::bind("").unwrap(),
|
||||||
|
router.clone().into_make_service(),
|
||||||
|
);
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
router
|
||||||
|
.clone()
|
||||||
|
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||||
|
);
|
||||||
|
serve(
|
||||||
|
UnixListener::bind("").unwrap(),
|
||||||
|
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// method router
|
// method router
|
||||||
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
|
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
|
||||||
|
serve(UnixListener::bind("").unwrap(), get(handler));
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
get(handler).into_make_service(),
|
get(handler).into_make_service(),
|
||||||
);
|
);
|
||||||
|
serve(
|
||||||
|
UnixListener::bind("").unwrap(),
|
||||||
|
get(handler).into_make_service(),
|
||||||
|
);
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
get(handler).into_make_service_with_connect_info::<SocketAddr>(),
|
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||||
|
);
|
||||||
|
serve(
|
||||||
|
UnixListener::bind("").unwrap(),
|
||||||
|
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// handler
|
// handler
|
||||||
|
@ -543,32 +503,28 @@ mod tests {
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
handler.into_service(),
|
handler.into_service(),
|
||||||
);
|
);
|
||||||
|
serve(UnixListener::bind("").unwrap(), handler.into_service());
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
handler.with_state(()),
|
handler.with_state(()),
|
||||||
);
|
);
|
||||||
|
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
handler.into_make_service(),
|
handler.into_make_service(),
|
||||||
);
|
);
|
||||||
|
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
|
||||||
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
TcpListener::bind(addr).await.unwrap(),
|
||||||
handler.into_make_service_with_connect_info::<SocketAddr>(),
|
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// nodelay
|
|
||||||
serve(
|
serve(
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
UnixListener::bind("").unwrap(),
|
||||||
handler.into_service(),
|
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||||
)
|
);
|
||||||
.tcp_nodelay(true);
|
|
||||||
|
|
||||||
serve(
|
|
||||||
TcpListener::bind(addr).await.unwrap(),
|
|
||||||
handler.into_service(),
|
|
||||||
)
|
|
||||||
.with_graceful_shutdown(async { /*...*/ })
|
|
||||||
.tcp_nodelay(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handler() {}
|
async fn handler() {}
|
||||||
|
@ -613,4 +569,49 @@ mod tests {
|
||||||
// Call Serve::into_future outside of a tokio context. This used to panic.
|
// Call Serve::into_future outside of a tokio context. This used to panic.
|
||||||
_ = serve(listener, router).into_future();
|
_ = serve(listener, router).into_future();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn serving_on_custom_io_type() {
|
||||||
|
struct ReadyListener<T>(Option<T>);
|
||||||
|
|
||||||
|
impl<T> Listener for ReadyListener<T>
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
type Io = T;
|
||||||
|
type Addr = ();
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||||
|
match self.0.take() {
|
||||||
|
Some(server) => (server, ()),
|
||||||
|
None => std::future::pending().await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (client, server) = io::duplex(1024);
|
||||||
|
let listener = ReadyListener(Some(server));
|
||||||
|
|
||||||
|
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||||
|
|
||||||
|
tokio::spawn(serve(listener, app).into_future());
|
||||||
|
|
||||||
|
let stream = TokioIo::new(client);
|
||||||
|
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
|
||||||
|
tokio::spawn(conn);
|
||||||
|
|
||||||
|
let request = Request::builder().body(Body::empty()).unwrap();
|
||||||
|
|
||||||
|
let response = sender.send_request(request).await.unwrap();
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let body = Body::new(response.into_body());
|
||||||
|
let body = to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body = String::from_utf8(body.to_vec()).unwrap();
|
||||||
|
assert_eq!(body, "Hello, World!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
167
axum/src/serve/listener.rs
Normal file
167
axum/src/serve/listener.rs
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
use std::{fmt, future::Future, time::Duration};
|
||||||
|
|
||||||
|
use tokio::{
|
||||||
|
io::{self, AsyncRead, AsyncWrite},
|
||||||
|
net::{TcpListener, TcpStream},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Types that can listen for connections.
|
||||||
|
pub trait Listener: Send + 'static {
|
||||||
|
/// The listener's IO type.
|
||||||
|
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
|
/// The listener's address type.
|
||||||
|
type Addr: Send;
|
||||||
|
|
||||||
|
/// Accept a new incoming connection to this listener.
|
||||||
|
///
|
||||||
|
/// If the underlying accept call can return an error, this function must
|
||||||
|
/// take care of logging and retrying.
|
||||||
|
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
|
||||||
|
|
||||||
|
/// Returns the local address that this listener is bound to.
|
||||||
|
fn local_addr(&self) -> io::Result<Self::Addr>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Listener for TcpListener {
|
||||||
|
type Io = TcpStream;
|
||||||
|
type Addr = std::net::SocketAddr;
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||||
|
loop {
|
||||||
|
match Self::accept(self).await {
|
||||||
|
Ok(tup) => return tup,
|
||||||
|
Err(e) => handle_accept_error(e).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||||
|
Self::local_addr(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
impl Listener for tokio::net::UnixListener {
|
||||||
|
type Io = tokio::net::UnixStream;
|
||||||
|
type Addr = tokio::net::unix::SocketAddr;
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||||
|
loop {
|
||||||
|
match Self::accept(self).await {
|
||||||
|
Ok(tup) => return tup,
|
||||||
|
Err(e) => handle_accept_error(e).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||||
|
Self::local_addr(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extensions to [`Listener`].
|
||||||
|
pub trait ListenerExt: Listener + Sized {
|
||||||
|
/// Run a mutable closure on every accepted `Io`.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::{Router, routing::get, serve::ListenerExt};
|
||||||
|
/// use tracing::trace;
|
||||||
|
///
|
||||||
|
/// # async {
|
||||||
|
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||||
|
///
|
||||||
|
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
|
||||||
|
/// .await
|
||||||
|
/// .unwrap()
|
||||||
|
/// .tap_io(|tcp_stream| {
|
||||||
|
/// if let Err(err) = tcp_stream.set_nodelay(true) {
|
||||||
|
/// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
||||||
|
/// }
|
||||||
|
/// });
|
||||||
|
/// axum::serve(listener, router).await.unwrap();
|
||||||
|
/// # };
|
||||||
|
/// ```
|
||||||
|
fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
|
||||||
|
where
|
||||||
|
F: FnMut(&mut Self::Io) + Send + 'static,
|
||||||
|
{
|
||||||
|
TapIo {
|
||||||
|
listener: self,
|
||||||
|
tap_fn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L: Listener> ListenerExt for L {}
|
||||||
|
|
||||||
|
/// Return type of [`ListenerExt::tap_io`].
|
||||||
|
///
|
||||||
|
/// See that method for details.
|
||||||
|
pub struct TapIo<L: Listener, F> {
|
||||||
|
listener: L,
|
||||||
|
tap_fn: F,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L, F> fmt::Debug for TapIo<L, F>
|
||||||
|
where
|
||||||
|
L: Listener + fmt::Debug,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("TapIo")
|
||||||
|
.field("listener", &self.listener)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<L, F> Listener for TapIo<L, F>
|
||||||
|
where
|
||||||
|
L: Listener,
|
||||||
|
F: FnMut(&mut L::Io) + Send + 'static,
|
||||||
|
{
|
||||||
|
type Io = L::Io;
|
||||||
|
type Addr = L::Addr;
|
||||||
|
|
||||||
|
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||||
|
let (mut io, addr) = self.listener.accept().await;
|
||||||
|
(self.tap_fn)(&mut io);
|
||||||
|
(io, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||||
|
self.listener.local_addr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_accept_error(e: io::Error) {
|
||||||
|
if is_connection_error(&e) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
|
||||||
|
//
|
||||||
|
// > A possible scenario is that the process has hit the max open files
|
||||||
|
// > allowed, and so trying to accept a new connection will fail with
|
||||||
|
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
||||||
|
// > the application will likely close some files (or connections), and try
|
||||||
|
// > to accept the connection again. If this option is `true`, the error
|
||||||
|
// > will be logged at the `error` level, since it is still a big deal,
|
||||||
|
// > and then the listener will sleep for 1 second.
|
||||||
|
//
|
||||||
|
// hyper allowed customizing this but axum does not.
|
||||||
|
error!("accept error: {e}");
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_connection_error(e: &io::Error) -> bool {
|
||||||
|
matches!(
|
||||||
|
e.kind(),
|
||||||
|
io::ErrorKind::ConnectionRefused
|
||||||
|
| io::ErrorKind::ConnectionAborted
|
||||||
|
| io::ErrorKind::ConnectionReset
|
||||||
|
)
|
||||||
|
}
|
|
@ -21,17 +21,13 @@ mod unix {
|
||||||
extract::connect_info::{self, ConnectInfo},
|
extract::connect_info::{self, ConnectInfo},
|
||||||
http::{Method, Request, StatusCode},
|
http::{Method, Request, StatusCode},
|
||||||
routing::get,
|
routing::get,
|
||||||
|
serve::IncomingStream,
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use hyper::body::Incoming;
|
use hyper_util::rt::TokioIo;
|
||||||
use hyper_util::{
|
use std::{path::PathBuf, sync::Arc};
|
||||||
rt::{TokioExecutor, TokioIo},
|
|
||||||
server,
|
|
||||||
};
|
|
||||||
use std::{convert::Infallible, path::PathBuf, sync::Arc};
|
|
||||||
use tokio::net::{unix::UCred, UnixListener, UnixStream};
|
use tokio::net::{unix::UCred, UnixListener, UnixStream};
|
||||||
use tower::Service;
|
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||||
|
|
||||||
pub async fn server() {
|
pub async fn server() {
|
||||||
|
@ -52,33 +48,11 @@ mod unix {
|
||||||
|
|
||||||
let uds = UnixListener::bind(path.clone()).unwrap();
|
let uds = UnixListener::bind(path.clone()).unwrap();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let app = Router::new().route("/", get(handler));
|
let app = Router::new()
|
||||||
|
.route("/", get(handler))
|
||||||
|
.into_make_service_with_connect_info::<UdsConnectInfo>();
|
||||||
|
|
||||||
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>();
|
axum::serve(uds, app).await.unwrap();
|
||||||
|
|
||||||
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
|
|
||||||
// more details about this setup
|
|
||||||
loop {
|
|
||||||
let (socket, _remote_addr) = uds.accept().await.unwrap();
|
|
||||||
|
|
||||||
let tower_service = unwrap_infallible(make_service.call(&socket).await);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let socket = TokioIo::new(socket);
|
|
||||||
|
|
||||||
let hyper_service =
|
|
||||||
hyper::service::service_fn(move |request: Request<Incoming>| {
|
|
||||||
tower_service.clone().call(request)
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
|
|
||||||
.serve_connection_with_upgrades(socket, hyper_service)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
eprintln!("failed to serve connection: {err:#}");
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
|
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
|
||||||
|
@ -117,22 +91,14 @@ mod unix {
|
||||||
peer_cred: UCred,
|
peer_cred: UCred,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
|
impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
|
||||||
fn connect_info(target: &UnixStream) -> Self {
|
fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
|
||||||
let peer_addr = target.peer_addr().unwrap();
|
let peer_addr = stream.io().peer_addr().unwrap();
|
||||||
let peer_cred = target.peer_cred().unwrap();
|
let peer_cred = stream.io().peer_cred().unwrap();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
peer_addr: Arc::new(peer_addr),
|
peer_addr: Arc::new(peer_addr),
|
||||||
peer_cred,
|
peer_cred,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
|
|
||||||
match result {
|
|
||||||
Ok(value) => value,
|
|
||||||
Err(err) => match err {},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue