From 84c3960639bde8a178e276d9808ba13fefa60957 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Sat, 30 Nov 2024 22:47:13 +0000 Subject: [PATCH] Make `serve` generic over the listener and IO types (#2941) Co-authored-by: David Pedersen --- axum/CHANGELOG.md | 2 + .../into_make_service_with_connect_info.md | 5 +- axum/src/extract/connect_info.rs | 15 +- axum/src/handler/service.rs | 7 +- axum/src/routing/method_routing.rs | 9 +- axum/src/routing/mod.rs | 9 +- axum/src/serve.rs | 383 +++++++++--------- axum/src/serve/listener.rs | 167 ++++++++ examples/unix-domain-socket/src/main.rs | 56 +-- 9 files changed, 399 insertions(+), 254 deletions(-) create mode 100644 axum/src/serve/listener.rs diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index ce98bca6..51654fa4 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`) This new variant captures both `key`, `value`, and `message` from named path parameters parse errors, instead of only deserialization error message in `ErrorKind::Message`. ([#2720]) +- **breaking:** Make `serve` generic over the listener and IO types ([#2941]) [#2897]: https://github.com/tokio-rs/axum/pull/2897 [#2903]: https://github.com/tokio-rs/axum/pull/2903 @@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2992]: https://github.com/tokio-rs/axum/pull/2992 [#2720]: https://github.com/tokio-rs/axum/pull/2720 [#3039]: https://github.com/tokio-rs/axum/pull/3039 +[#2941]: https://github.com/tokio-rs/axum/pull/2941 # 0.8.0 diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 26d0602f..088f21f9 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -35,6 +35,7 @@ use axum::{ serve::IncomingStream, Router, }; +use tokio::net::TcpListener; let app = Router::new().route("/", get(handler)); @@ -49,8 +50,8 @@ struct MyConnectInfo { // ... } -impl Connected> for MyConnectInfo { - fn connect_info(target: IncomingStream<'_>) -> Self { +impl Connected> for MyConnectInfo { + fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self { MyConnectInfo { // ... } diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 3d8f9a01..54a8d775 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -79,16 +79,17 @@ where /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. - fn connect_info(target: T) -> Self; + fn connect_info(stream: T) -> Self; } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; + use tokio::net::TcpListener; - impl Connected> for SocketAddr { - fn connect_info(target: IncomingStream<'_>) -> Self { - target.remote_addr() + impl Connected> for SocketAddr { + fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self { + *stream.remote_addr() } } }; @@ -261,8 +262,8 @@ mod tests { value: &'static str, } - impl Connected> for MyConnectInfo { - fn connect_info(_target: IncomingStream<'_>) -> Self { + impl Connected> for MyConnectInfo { + fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self { Self { value: "it worked!", } diff --git a/axum/src/handler/service.rs b/axum/src/handler/service.rs index e6b8df93..20900519 100644 --- a/axum/src/handler/service.rs +++ b/axum/src/handler/service.rs @@ -180,12 +180,13 @@ where // for `axum::serve(listener, handler)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for HandlerService + impl Service> for HandlerService where H: Clone, S: Clone, + L: serve::Listener, { type Response = Self; type Error = Infallible; @@ -195,7 +196,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index cfc47e1f..2c03e7d5 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1313,9 +1313,12 @@ where // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for MethodRouter<()> { + impl Service> for MethodRouter<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -1324,7 +1327,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone().with_state(()))) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index d1e84d6a..6404ce7d 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -518,9 +518,12 @@ impl Router { // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for Router<()> { + impl Service> for Router<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -529,7 +532,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request std::future::ready(Ok(self.clone().with_state(()))) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 87b103ee..aa205c8d 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -1,14 +1,13 @@ //! Serve services. use std::{ + any::TypeId, convert::Infallible, fmt::Debug, future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, - net::SocketAddr, sync::Arc, - time::Duration, }; use axum_core::{body::Body, extract::Request, response::Response}; @@ -17,13 +16,14 @@ use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; -use tokio::{ - net::{TcpListener, TcpStream}, - sync::watch, -}; +use tokio::{net::TcpListener, sync::watch}; use tower::ServiceExt as _; use tower_service::Service; +mod listener; + +pub use self::listener::{Listener, ListenerExt, TapIo}; + /// Serve the service with the supplied listener. /// /// This method of running a service is intentionally simple and doesn't support any configuration. @@ -89,14 +89,15 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(tcp_listener: TcpListener, make_service: M) -> Serve +pub fn serve(listener: L, make_service: M) -> Serve where - M: for<'a> Service, Error = Infallible, Response = S>, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { Serve { - tcp_listener, + listener, make_service, tcp_nodelay: None, _marker: PhantomData, @@ -106,15 +107,18 @@ where /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct Serve { - tcp_listener: TcpListener, +pub struct Serve { + listener: L, make_service: M, tcp_nodelay: Option, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve { +impl Serve +where + L: Listener, +{ /// Prepares a server to handle graceful shutdown when the provided future completes. /// /// # Example @@ -136,76 +140,57 @@ impl Serve { /// // ... /// } /// ``` - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { - tcp_listener: self.tcp_listener, + listener: self.listener, make_service: self.make_service, signal, - tcp_nodelay: self.tcp_nodelay, _marker: PhantomData, } } - /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. - /// - /// See also [`TcpStream::set_nodelay`]. - /// - /// # Example - /// ``` - /// use axum::{Router, routing::get}; - /// - /// # async { - /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); - /// - /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - /// axum::serve(listener, router) - /// .tcp_nodelay(true) - /// .await - /// .unwrap(); - /// # }; - /// ``` - pub fn tcp_nodelay(self, nodelay: bool) -> Self { - Self { - tcp_nodelay: Some(nodelay), - ..self - } - } - /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +impl Debug for Serve where + L: Debug + 'static, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, tcp_nodelay, _marker: _, } = self; - f.debug_struct("Serve") - .field("tcp_listener", tcp_listener) - .field("make_service", make_service) - .field("tcp_nodelay", tcp_nodelay) - .finish() + let mut s = f.debug_struct("Serve"); + s.field("listener", listener) + .field("make_service", make_service); + + if TypeId::of::() == TypeId::of::() { + s.field("tcp_nodelay", tcp_nodelay); + } + + s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +impl IntoFuture for Serve where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { @@ -221,81 +206,55 @@ where /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct WithGracefulShutdown { - tcp_listener: TcpListener, +pub struct WithGracefulShutdown { + listener: L, make_service: M, signal: F, - tcp_nodelay: Option, _marker: PhantomData, } -impl WithGracefulShutdown { - /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. - /// - /// See also [`TcpStream::set_nodelay`]. - /// - /// # Example - /// ``` - /// use axum::{Router, routing::get}; - /// - /// # async { - /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); - /// - /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - /// axum::serve(listener, router) - /// .with_graceful_shutdown(shutdown_signal()) - /// .tcp_nodelay(true) - /// .await - /// .unwrap(); - /// # }; - /// - /// async fn shutdown_signal() { - /// // ... - /// } - /// ``` - pub fn tcp_nodelay(self, nodelay: bool) -> Self { - Self { - tcp_nodelay: Some(nodelay), - ..self - } - } - +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown +where + L: Listener, +{ /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +impl Debug for WithGracefulShutdown where + L: Debug + 'static, M: Debug, S: Debug, F: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, signal, - tcp_nodelay, _marker: _, } = self; f.debug_struct("WithGracefulShutdown") - .field("tcp_listener", tcp_listener) + .field("listener", listener) .field("make_service", make_service) .field("signal", signal) - .field("tcp_nodelay", tcp_nodelay) .finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +impl IntoFuture for WithGracefulShutdown where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, @@ -305,10 +264,9 @@ where fn into_future(self) -> Self::IntoFuture { let Self { - tcp_listener, + mut listener, mut make_service, signal, - tcp_nodelay, _marker: _, } = self; @@ -324,28 +282,17 @@ where let (close_tx, close_rx) = watch::channel(()); loop { - let (tcp_stream, remote_addr) = tokio::select! { - conn = tcp_accept(&tcp_listener) => { - match conn { - Some(conn) => conn, - None => continue, - } - } + let (io, remote_addr) = tokio::select! { + conn = listener.accept() => conn, _ = signal_tx.closed() => { trace!("signal received, not accepting new connections"); break; } }; - if let Some(nodelay) = tcp_nodelay { - if let Err(err) = tcp_stream.set_nodelay(nodelay) { - trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); - } - } + let io = TokioIo::new(io); - let tcp_stream = TokioIo::new(tcp_stream); - - trace!("connection {remote_addr} accepted"); + trace!("connection {remote_addr:?} accepted"); poll_fn(|cx| make_service.poll_ready(cx)) .await @@ -353,7 +300,7 @@ where let tower_service = make_service .call(IncomingStream { - tcp_stream: &tcp_stream, + io: &io, remote_addr, }) .await @@ -372,7 +319,7 @@ where // CONNECT protocol needed for HTTP/2 websockets #[cfg(feature = "http2")] builder.http2().enable_connect_protocol(); - let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); + let conn = builder.serve_connection_with_upgrades(io, hyper_service); pin_mut!(conn); let signal_closed = signal_tx.closed().fuse(); @@ -393,14 +340,12 @@ where } } - trace!("connection {remote_addr} closed"); - drop(close_rx); }); } drop(close_rx); - drop(tcp_listener); + drop(listener); trace!( "waiting for {} task(s) to finish", @@ -413,38 +358,32 @@ where } } -fn is_connection_error(e: &io::Error) -> bool { - matches!( - e.kind(), - io::ErrorKind::ConnectionRefused - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::ConnectionReset - ) +/// An incoming stream. +/// +/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. +/// +/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo +#[derive(Debug)] +pub struct IncomingStream<'a, L> +where + L: Listener, +{ + io: &'a TokioIo, + remote_addr: L::Addr, } -async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { - match listener.accept().await { - Ok(conn) => Some(conn), - Err(e) => { - if is_connection_error(&e) { - return None; - } +impl IncomingStream<'_, L> +where + L: Listener, +{ + /// Get a reference to the inner IO type. + pub fn io(&self) -> &L::Io { + self.io.inner() + } - // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) - // - // > A possible scenario is that the process has hit the max open files - // > allowed, and so trying to accept a new connection will fail with - // > `EMFILE`. In some cases, it's preferable to just wait for some time, if - // > the application will likely close some files (or connections), and try - // > to accept the connection again. If this option is `true`, the error - // > will be logged at the `error` level, since it is still a big deal, - // > and then the listener will sleep for 1 second. - // - // hyper allowed customizing this but axum does not. - error!("accept error: {e}"); - tokio::time::sleep(Duration::from_secs(1)).await; - None - } + /// Returns the remote address that this stream is bound to. + pub fn remote_addr(&self) -> &L::Addr { + &self.remote_addr } } @@ -474,68 +413,89 @@ mod private { } } -/// An incoming stream. -/// -/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. -/// -/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo -#[derive(Debug)] -pub struct IncomingStream<'a> { - tcp_stream: &'a TokioIo, - remote_addr: SocketAddr, -} - -impl IncomingStream<'_> { - /// Returns the local address that this stream is bound to. - pub fn local_addr(&self) -> std::io::Result { - self.tcp_stream.inner().local_addr() - } - - /// Returns the remote address that this stream is bound to. - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } -} - #[cfg(test)] mod tests { - use super::*; + use std::{ + future::{pending, IntoFuture as _}, + net::{IpAddr, Ipv4Addr}, + }; + + use axum_core::{body::Body, extract::Request}; + use http::StatusCode; + use hyper_util::rt::TokioIo; + use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, UnixListener}, + }; + + use super::{serve, IncomingStream, Listener}; use crate::{ + body::to_bytes, + extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, Router, }; - use std::{ - future::pending, - net::{IpAddr, Ipv4Addr}, - }; #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { + #[derive(Clone, Debug)] + struct UdsConnectInfo; + + impl Connected> for UdsConnectInfo { + fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self { + Self + } + } + let router: Router = Router::new(); let addr = "0.0.0.0:0"; // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(UnixListener::bind("").unwrap(), router.clone()); + serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + router.clone().into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - router.into_make_service_with_connect_info::(), + router + .clone() + .into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + router.into_make_service_with_connect_info::(), ); // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(UnixListener::bind("").unwrap(), get(handler)); + serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - get(handler).into_make_service_with_connect_info::(), + get(handler).into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service_with_connect_info::(), ); // handler @@ -543,32 +503,28 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_service()); + serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(UnixListener::bind("").unwrap(), handler.with_state(())); + serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_make_service()); + serve( TcpListener::bind(addr).await.unwrap(), - handler.into_make_service_with_connect_info::(), + handler.into_make_service_with_connect_info::(), ); - - // nodelay serve( - TcpListener::bind(addr).await.unwrap(), - handler.into_service(), - ) - .tcp_nodelay(true); - - serve( - TcpListener::bind(addr).await.unwrap(), - handler.into_service(), - ) - .with_graceful_shutdown(async { /*...*/ }) - .tcp_nodelay(true); + UnixListener::bind("").unwrap(), + handler.into_make_service_with_connect_info::(), + ); } async fn handler() {} @@ -613,4 +569,49 @@ mod tests { // Call Serve::into_future outside of a tokio context. This used to panic. _ = serve(listener, router).into_future(); } + + #[crate::test] + async fn serving_on_custom_io_type() { + struct ReadyListener(Option); + + impl Listener for ReadyListener + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Io = T; + type Addr = (); + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + match self.0.take() { + Some(server) => (server, ()), + None => std::future::pending().await, + } + } + + fn local_addr(&self) -> io::Result { + Ok(()) + } + } + + let (client, server) = io::duplex(1024); + let listener = ReadyListener(Some(server)); + + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + tokio::spawn(serve(listener, app).into_future()); + + let stream = TokioIo::new(client); + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); + tokio::spawn(conn); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = sender.send_request(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let body = Body::new(response.into_body()); + let body = to_bytes(body, usize::MAX).await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); + } } diff --git a/axum/src/serve/listener.rs b/axum/src/serve/listener.rs new file mode 100644 index 00000000..8458b957 --- /dev/null +++ b/axum/src/serve/listener.rs @@ -0,0 +1,167 @@ +use std::{fmt, future::Future, time::Duration}; + +use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, TcpStream}, +}; + +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + type Addr: Send; + + /// Accept a new incoming connection to this listener. + /// + /// If the underlying accept call can return an error, this function must + /// take care of logging and retrying. + fn accept(&mut self) -> impl Future + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> io::Result; +} + +impl Listener for TcpListener { + type Io = TcpStream; + type Addr = std::net::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + loop { + match Self::accept(self).await { + Ok(tup) => return tup, + Err(e) => handle_accept_error(e).await, + } + } + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +/// Extensions to [`Listener`]. +pub trait ListenerExt: Listener + Sized { + /// Run a mutable closure on every accepted `Io`. + /// + /// # Example + /// + /// ``` + /// use axum::{Router, routing::get, serve::ListenerExt}; + /// use tracing::trace; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + /// .await + /// .unwrap() + /// .tap_io(|tcp_stream| { + /// if let Err(err) = tcp_stream.set_nodelay(true) { + /// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + /// } + /// }); + /// axum::serve(listener, router).await.unwrap(); + /// # }; + /// ``` + fn tap_io(self, tap_fn: F) -> TapIo + where + F: FnMut(&mut Self::Io) + Send + 'static, + { + TapIo { + listener: self, + tap_fn, + } + } +} + +impl ListenerExt for L {} + +/// Return type of [`ListenerExt::tap_io`]. +/// +/// See that method for details. +pub struct TapIo { + listener: L, + tap_fn: F, +} + +impl fmt::Debug for TapIo +where + L: Listener + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TapIo") + .field("listener", &self.listener) + .finish_non_exhaustive() + } +} + +impl Listener for TapIo +where + L: Listener, + F: FnMut(&mut L::Io) + Send + 'static, +{ + type Io = L::Io; + type Addr = L::Addr; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + let (mut io, addr) = self.listener.accept().await; + (self.tap_fn)(&mut io); + (io, addr) + } + + fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + +async fn handle_accept_error(e: io::Error) { + if is_connection_error(&e) { + return; + } + + // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) + // + // > A possible scenario is that the process has hit the max open files + // > allowed, and so trying to accept a new connection will fail with + // > `EMFILE`. In some cases, it's preferable to just wait for some time, if + // > the application will likely close some files (or connections), and try + // > to accept the connection again. If this option is `true`, the error + // > will be logged at the `error` level, since it is still a big deal, + // > and then the listener will sleep for 1 second. + // + // hyper allowed customizing this but axum does not. + error!("accept error: {e}"); + tokio::time::sleep(Duration::from_secs(1)).await; +} + +fn is_connection_error(e: &io::Error) -> bool { + matches!( + e.kind(), + io::ErrorKind::ConnectionRefused + | io::ErrorKind::ConnectionAborted + | io::ErrorKind::ConnectionReset + ) +} diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index 697f31a5..07b38d91 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -21,17 +21,13 @@ mod unix { extract::connect_info::{self, ConnectInfo}, http::{Method, Request, StatusCode}, routing::get, + serve::IncomingStream, Router, }; use http_body_util::BodyExt; - use hyper::body::Incoming; - use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, - }; - use std::{convert::Infallible, path::PathBuf, sync::Arc}; + use hyper_util::rt::TokioIo; + use std::{path::PathBuf, sync::Arc}; use tokio::net::{unix::UCred, UnixListener, UnixStream}; - use tower::Service; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub async fn server() { @@ -52,33 +48,11 @@ mod unix { let uds = UnixListener::bind(path.clone()).unwrap(); tokio::spawn(async move { - let app = Router::new().route("/", get(handler)); + let app = Router::new() + .route("/", get(handler)) + .into_make_service_with_connect_info::(); - let mut make_service = app.into_make_service_with_connect_info::(); - - // See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for - // more details about this setup - loop { - let (socket, _remote_addr) = uds.accept().await.unwrap(); - - let tower_service = unwrap_infallible(make_service.call(&socket).await); - - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - tower_service.clone().call(request) - }); - - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } + axum::serve(uds, app).await.unwrap(); }); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); @@ -117,22 +91,14 @@ mod unix { peer_cred: UCred, } - impl connect_info::Connected<&UnixStream> for UdsConnectInfo { - fn connect_info(target: &UnixStream) -> Self { - let peer_addr = target.peer_addr().unwrap(); - let peer_cred = target.peer_cred().unwrap(); - + impl connect_info::Connected> for UdsConnectInfo { + fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self { + let peer_addr = stream.io().peer_addr().unwrap(); + let peer_cred = stream.io().peer_cred().unwrap(); Self { peer_addr: Arc::new(peer_addr), peer_cred, } } } - - fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, - } - } }