Make serve generic over the listener and IO types (#2941)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2024-11-30 22:47:13 +00:00 committed by GitHub
parent d84136e1e4
commit 84c3960639
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 399 additions and 254 deletions

View file

@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
[#2897]: https://github.com/tokio-rs/axum/pull/2897
[#2903]: https://github.com/tokio-rs/axum/pull/2903
@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#2992]: https://github.com/tokio-rs/axum/pull/2992
[#2720]: https://github.com/tokio-rs/axum/pull/2720
[#3039]: https://github.com/tokio-rs/axum/pull/3039
[#2941]: https://github.com/tokio-rs/axum/pull/2941
# 0.8.0

View file

@ -35,6 +35,7 @@ use axum::{
serve::IncomingStream,
Router,
};
use tokio::net::TcpListener;
let app = Router::new().route("/", get(handler));
@ -49,8 +50,8 @@ struct MyConnectInfo {
// ...
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
MyConnectInfo {
// ...
}

View file

@ -79,16 +79,17 @@ where
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static {
/// Create type holding information about the connection.
fn connect_info(target: T) -> Self;
fn connect_info(stream: T) -> Self;
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
use tokio::net::TcpListener;
impl Connected<IncomingStream<'_>> for SocketAddr {
fn connect_info(target: IncomingStream<'_>) -> Self {
target.remote_addr()
impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
*stream.remote_addr()
}
}
};
@ -261,8 +262,8 @@ mod tests {
value: &'static str,
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_>) -> Self {
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
Self {
value: "it worked!",
}

View file

@ -180,12 +180,13 @@ where
// for `axum::serve(listener, handler)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
where
H: Clone,
S: Clone,
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
@ -195,7 +196,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone()))
}
}

View file

@ -1313,9 +1313,12 @@ where
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl Service<IncomingStream<'_>> for MethodRouter<()> {
impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
where
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -1324,7 +1327,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone().with_state(())))
}
}

View file

@ -518,9 +518,12 @@ impl Router {
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl Service<IncomingStream<'_>> for Router<()> {
impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
where
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -529,7 +532,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
// call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request
std::future::ready(Ok(self.clone().with_state(())))

View file

@ -1,14 +1,13 @@
//! Serve services.
use std::{
any::TypeId,
convert::Infallible,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
sync::Arc,
time::Duration,
};
use axum_core::{body::Body, extract::Request, response::Response};
@ -17,13 +16,14 @@ use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tokio::{net::TcpListener, sync::watch};
use tower::ServiceExt as _;
use tower_service::Service;
mod listener;
pub use self::listener::{Listener, ListenerExt, TapIo};
/// Serve the service with the supplied listener.
///
/// This method of running a service is intentionally simple and doesn't support any configuration.
@ -89,14 +89,15 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
Serve {
tcp_listener,
listener,
make_service,
tcp_nodelay: None,
_marker: PhantomData,
@ -106,15 +107,18 @@ where
/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
tcp_listener: TcpListener,
pub struct Serve<L, M, S> {
listener: L,
make_service: M,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
impl<L, M, S> Serve<L, M, S>
where
L: Listener,
{
/// Prepares a server to handle graceful shutdown when the provided future completes.
///
/// # Example
@ -136,76 +140,57 @@ impl<M, S> Serve<M, S> {
/// // ...
/// }
/// ```
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
tcp_listener: self.tcp_listener,
listener: self.listener,
make_service: self.make_service,
signal,
tcp_nodelay: self.tcp_nodelay,
_marker: PhantomData,
}
}
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
///
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
impl<L, M, S> Debug for Serve<L, M, S>
where
L: Debug + 'static,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
listener,
make_service,
tcp_nodelay,
_marker: _,
} = self;
f.debug_struct("Serve")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.field("tcp_nodelay", tcp_nodelay)
.finish()
let mut s = f.debug_struct("Serve");
s.field("listener", listener)
.field("make_service", make_service);
if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
@ -221,81 +206,55 @@ where
/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener,
pub struct WithGracefulShutdown<L, M, S, F> {
listener: L,
make_service: M,
signal: F,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>,
}
impl<M, S, F> WithGracefulShutdown<M, S, F> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
///
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .with_graceful_shutdown(shutdown_signal())
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
///
/// async fn shutdown_signal() {
/// // ...
/// }
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
L: Listener,
{
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
L: Debug + 'static,
M: Debug,
S: Debug,
F: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
listener,
make_service,
signal,
tcp_nodelay,
_marker: _,
} = self;
f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener)
.field("listener", listener)
.field("make_service", make_service)
.field("signal", signal)
.field("tcp_nodelay", tcp_nodelay)
.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
@ -305,10 +264,9 @@ where
fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut listener,
mut make_service,
signal,
tcp_nodelay,
_marker: _,
} = self;
@ -324,28 +282,17 @@ where
let (close_tx, close_rx) = watch::channel(());
loop {
let (tcp_stream, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => {
match conn {
Some(conn) => conn,
None => continue,
}
}
let (io, remote_addr) = tokio::select! {
conn = listener.accept() => conn,
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};
if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let io = TokioIo::new(io);
let tcp_stream = TokioIo::new(tcp_stream);
trace!("connection {remote_addr} accepted");
trace!("connection {remote_addr:?} accepted");
poll_fn(|cx| make_service.poll_ready(cx))
.await
@ -353,7 +300,7 @@ where
let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
io: &io,
remote_addr,
})
.await
@ -372,7 +319,7 @@ where
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn);
let signal_closed = signal_tx.closed().fuse();
@ -393,14 +340,12 @@ where
}
}
trace!("connection {remote_addr} closed");
drop(close_rx);
});
}
drop(close_rx);
drop(tcp_listener);
drop(listener);
trace!(
"waiting for {} task(s) to finish",
@ -413,38 +358,32 @@ where
}
}
fn is_connection_error(e: &io::Error) -> bool {
matches!(
e.kind(),
io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::ConnectionReset
)
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
L: Listener,
{
io: &'a TokioIo<L::Io>,
remote_addr: L::Addr,
}
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
match listener.accept().await {
Ok(conn) => Some(conn),
Err(e) => {
if is_connection_error(&e) {
return None;
}
impl<L> IncomingStream<'_, L>
where
L: Listener,
{
/// Get a reference to the inner IO type.
pub fn io(&self) -> &L::Io {
self.io.inner()
}
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
None
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> &L::Addr {
&self.remote_addr
}
}
@ -474,68 +413,89 @@ mod private {
}
}
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a> {
tcp_stream: &'a TokioIo<TcpStream>,
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_stream.inner().local_addr()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{
future::{pending, IntoFuture as _},
net::{IpAddr, Ipv4Addr},
};
use axum_core::{body::Body, extract::Request};
use http::StatusCode;
use hyper_util::rt::TokioIo;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, UnixListener},
};
use super::{serve, IncomingStream, Listener};
use crate::{
body::to_bytes,
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt},
routing::get,
Router,
};
use std::{
future::pending,
net::{IpAddr, Ipv4Addr},
};
#[allow(dead_code, unused_must_use)]
async fn if_it_compiles_it_works() {
#[derive(Clone, Debug)]
struct UdsConnectInfo;
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
Self
}
}
let router: Router = Router::new();
let addr = "0.0.0.0:0";
// router
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(UnixListener::bind("").unwrap(), router.clone());
serve(
TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
router.into_make_service_with_connect_info::<SocketAddr>(),
router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
);
// method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<SocketAddr>(),
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
);
// handler
@ -543,32 +503,28 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
);
serve(UnixListener::bind("").unwrap(), handler.into_service());
serve(
TcpListener::bind(addr).await.unwrap(),
handler.with_state(()),
);
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(),
);
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(),
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
// nodelay
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
)
.tcp_nodelay(true);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
)
.with_graceful_shutdown(async { /*...*/ })
.tcp_nodelay(true);
UnixListener::bind("").unwrap(),
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
);
}
async fn handler() {}
@ -613,4 +569,49 @@ mod tests {
// Call Serve::into_future outside of a tokio context. This used to panic.
_ = serve(listener, router).into_future();
}
#[crate::test]
async fn serving_on_custom_io_type() {
struct ReadyListener<T>(Option<T>);
impl<T> Listener for ReadyListener<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Io = T;
type Addr = ();
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
match self.0.take() {
Some(server) => (server, ()),
None => std::future::pending().await,
}
}
fn local_addr(&self) -> io::Result<Self::Addr> {
Ok(())
}
}
let (client, server) = io::duplex(1024);
let listener = ReadyListener(Some(server));
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
tokio::spawn(serve(listener, app).into_future());
let stream = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
tokio::spawn(conn);
let request = Request::builder().body(Body::empty()).unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = Body::new(response.into_body());
let body = to_bytes(body, usize::MAX).await.unwrap();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}
}

167
axum/src/serve/listener.rs Normal file
View file

@ -0,0 +1,167 @@
use std::{fmt, future::Future, time::Duration};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};
/// Types that can listen for connections.
pub trait Listener: Send + 'static {
/// The listener's IO type.
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// The listener's address type.
type Addr: Send;
/// Accept a new incoming connection to this listener.
///
/// If the underlying accept call can return an error, this function must
/// take care of logging and retrying.
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
/// Returns the local address that this listener is bound to.
fn local_addr(&self) -> io::Result<Self::Addr>;
}
impl Listener for TcpListener {
type Io = TcpStream;
type Addr = std::net::SocketAddr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
loop {
match Self::accept(self).await {
Ok(tup) => return tup,
Err(e) => handle_accept_error(e).await,
}
}
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Io = tokio::net::UnixStream;
type Addr = tokio::net::unix::SocketAddr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
loop {
match Self::accept(self).await {
Ok(tup) => return tup,
Err(e) => handle_accept_error(e).await,
}
}
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
/// Extensions to [`Listener`].
pub trait ListenerExt: Listener + Sized {
/// Run a mutable closure on every accepted `Io`.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get, serve::ListenerExt};
/// use tracing::trace;
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
/// .await
/// .unwrap()
/// .tap_io(|tcp_stream| {
/// if let Err(err) = tcp_stream.set_nodelay(true) {
/// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
/// }
/// });
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
where
F: FnMut(&mut Self::Io) + Send + 'static,
{
TapIo {
listener: self,
tap_fn,
}
}
}
impl<L: Listener> ListenerExt for L {}
/// Return type of [`ListenerExt::tap_io`].
///
/// See that method for details.
pub struct TapIo<L: Listener, F> {
listener: L,
tap_fn: F,
}
impl<L, F> fmt::Debug for TapIo<L, F>
where
L: Listener + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TapIo")
.field("listener", &self.listener)
.finish_non_exhaustive()
}
}
impl<L, F> Listener for TapIo<L, F>
where
L: Listener,
F: FnMut(&mut L::Io) + Send + 'static,
{
type Io = L::Io;
type Addr = L::Addr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
let (mut io, addr) = self.listener.accept().await;
(self.tap_fn)(&mut io);
(io, addr)
}
fn local_addr(&self) -> io::Result<Self::Addr> {
self.listener.local_addr()
}
}
async fn handle_accept_error(e: io::Error) {
if is_connection_error(&e) {
return;
}
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
fn is_connection_error(e: &io::Error) -> bool {
matches!(
e.kind(),
io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::ConnectionReset
)
}

View file

@ -21,17 +21,13 @@ mod unix {
extract::connect_info::{self, ConnectInfo},
http::{Method, Request, StatusCode},
routing::get,
serve::IncomingStream,
Router,
};
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server,
};
use std::{convert::Infallible, path::PathBuf, sync::Arc};
use hyper_util::rt::TokioIo;
use std::{path::PathBuf, sync::Arc};
use tokio::net::{unix::UCred, UnixListener, UnixStream};
use tower::Service;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
pub async fn server() {
@ -52,33 +48,11 @@ mod unix {
let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async move {
let app = Router::new().route("/", get(handler));
let app = Router::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<UdsConnectInfo>();
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>();
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
// more details about this setup
loop {
let (socket, _remote_addr) = uds.accept().await.unwrap();
let tower_service = unwrap_infallible(make_service.call(&socket).await);
tokio::spawn(async move {
let socket = TokioIo::new(socket);
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(socket, hyper_service)
.await
{
eprintln!("failed to serve connection: {err:#}");
}
});
}
axum::serve(uds, app).await.unwrap();
});
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
@ -117,22 +91,14 @@ mod unix {
peer_cred: UCred,
}
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
fn connect_info(target: &UnixStream) -> Self {
let peer_addr = target.peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap();
impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
let peer_addr = stream.io().peer_addr().unwrap();
let peer_cred = stream.io().peer_cred().unwrap();
Self {
peer_addr: Arc::new(peer_addr),
peer_cred,
}
}
}
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
}