Make serve generic over the listener and IO types

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2024-09-27 23:57:26 +02:00
parent 689ca1aea2
commit 26a37a4679
No known key found for this signature in database
GPG key ID: 7D261D771D915378
9 changed files with 297 additions and 142 deletions

View file

@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
- **change:** Update minimum rust version to 1.75 ([#2943])
[#2931]: https://github.com/tokio-rs/axum/pull/2931
[#2941]: https://github.com/tokio-rs/axum/pull/2941
[#2943]: https://github.com/tokio-rs/axum/pull/2943
# 0.7.7

View file

@ -113,6 +113,7 @@ features = [
[dev-dependencies]
anyhow = "1.0"
axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }

View file

@ -35,6 +35,7 @@ use axum::{
serve::IncomingStream,
Router,
};
use tokio::net::TcpListener;
let app = Router::new().route("/", get(handler));
@ -49,8 +50,8 @@ struct MyConnectInfo {
// ...
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
MyConnectInfo {
// ...
}

View file

@ -80,16 +80,17 @@ where
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static {
/// Create type holding information about the connection.
fn connect_info(target: T) -> Self;
fn connect_info(stream: T) -> Self;
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
use tokio::net::TcpListener;
impl Connected<IncomingStream<'_>> for SocketAddr {
fn connect_info(target: IncomingStream<'_>) -> Self {
target.remote_addr()
impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
*stream.remote_addr()
}
}
};
@ -263,8 +264,8 @@ mod tests {
value: &'static str,
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_>) -> Self {
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
Self {
value: "it worked!",
}

View file

@ -180,12 +180,13 @@ where
// for `axum::serve(listener, handler)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
where
H: Clone,
S: Clone,
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
@ -195,7 +196,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone()))
}
}

View file

@ -1227,9 +1227,12 @@ where
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl Service<IncomingStream<'_>> for MethodRouter<()> {
impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
where
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -1238,7 +1241,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone().with_state(())))
}
}

View file

@ -486,9 +486,12 @@ impl Router {
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve::IncomingStream;
use crate::serve;
impl Service<IncomingStream<'_>> for Router<()> {
impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
where
L: serve::Listener,
{
type Response = Self;
type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -497,7 +500,7 @@ const _: () = {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
// call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request
std::future::ready(Ok(self.clone().with_state(())))

View file

@ -1,12 +1,12 @@
//! Serve services.
use std::{
any::TypeId,
convert::Infallible,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
sync::Arc,
time::Duration,
};
@ -18,12 +18,59 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::ServiceExt as _;
use tower_service::Service;
/// Types that can listen for connections.
pub trait Listener: Send + 'static {
/// The listener's IO type.
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// The listener's address type.
type Addr: Send;
/// Accept a new incoming connection to this listener
fn accept(&mut self) -> impl Future<Output = io::Result<(Self::Io, Self::Addr)>> + Send;
/// Returns the local address that this listener is bound to.
fn local_addr(&self) -> io::Result<Self::Addr>;
}
impl Listener for TcpListener {
type Io = TcpStream;
type Addr = std::net::SocketAddr;
#[inline]
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
Self::accept(self).await
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Io = tokio::net::UnixStream;
type Addr = tokio::net::unix::SocketAddr;
#[inline]
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
Self::accept(self).await
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
/// Serve the service with the supplied listener.
///
/// This method of running a service is intentionally simple and doesn't support any configuration.
@ -89,14 +136,15 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
Serve {
tcp_listener,
listener,
make_service,
tcp_nodelay: None,
_marker: PhantomData,
@ -106,15 +154,18 @@ where
/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
tcp_listener: TcpListener,
pub struct Serve<L, M, S> {
listener: L,
make_service: M,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
impl<L, M, S> Serve<L, M, S>
where
L: Listener,
{
/// Prepares a server to handle graceful shutdown when the provided future completes.
///
/// # Example
@ -136,12 +187,12 @@ impl<M, S> Serve<M, S> {
/// // ...
/// }
/// ```
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
tcp_listener: self.tcp_listener,
listener: self.listener,
make_service: self.make_service,
signal,
tcp_nodelay: self.tcp_nodelay,
@ -149,6 +200,14 @@ impl<M, S> Serve<M, S> {
}
}
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<TcpListener, M, S> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
@ -173,39 +232,41 @@ impl<M, S> Serve<M, S> {
..self
}
}
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
impl<L, M, S> Debug for Serve<L, M, S>
where
L: Debug + 'static,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
listener,
make_service,
tcp_nodelay,
_marker: _,
} = self;
f.debug_struct("Serve")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.field("tcp_nodelay", tcp_nodelay)
.finish()
let mut s = f.debug_struct("Serve");
s.field("listener", listener)
.field("make_service", make_service);
if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
@ -221,15 +282,27 @@ where
/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener,
pub struct WithGracefulShutdown<L, M, S, F> {
listener: L,
make_service: M,
signal: F,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>,
}
impl<M, S, F> WithGracefulShutdown<M, S, F> {
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
L: Listener,
{
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> WithGracefulShutdown<TcpListener, M, S, F> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
@ -259,43 +332,45 @@ impl<M, S, F> WithGracefulShutdown<M, S, F> {
..self
}
}
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
L: Debug + 'static,
M: Debug,
S: Debug,
F: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
listener,
make_service,
signal,
tcp_nodelay,
_marker: _,
} = self;
f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener)
let mut s = f.debug_struct("WithGracefulShutdown");
s.field("listener", listener)
.field("make_service", make_service)
.field("signal", signal)
.field("tcp_nodelay", tcp_nodelay)
.finish()
.field("signal", signal);
if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
@ -305,7 +380,7 @@ where
fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut listener,
mut make_service,
signal,
tcp_nodelay,
@ -324,8 +399,8 @@ where
private::ServeFuture(Box::pin(async move {
loop {
let (tcp_stream, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => {
let (io, remote_addr) = tokio::select! {
conn = accept(&mut listener) => {
match conn {
Some(conn) => conn,
None => continue,
@ -338,14 +413,16 @@ where
};
if let Some(nodelay) = tcp_nodelay {
let tcp_stream: &tokio::net::TcpStream = <dyn std::any::Any>::downcast_ref(&io)
.expect("internal error: tcp_nodelay used with the wrong type of listener");
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let tcp_stream = TokioIo::new(tcp_stream);
let io = TokioIo::new(io);
trace!("connection {remote_addr} accepted");
trace!("connection {remote_addr:?} accepted");
poll_fn(|cx| make_service.poll_ready(cx))
.await
@ -353,7 +430,7 @@ where
let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
io: &io,
remote_addr,
})
.await
@ -368,7 +445,7 @@ where
tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn);
let signal_closed = signal_tx.closed().fuse();
@ -389,14 +466,12 @@ where
}
}
trace!("connection {remote_addr} closed");
drop(close_rx);
});
}
drop(close_rx);
drop(tcp_listener);
drop(listener);
trace!(
"waiting for {} task(s) to finish",
@ -418,7 +493,10 @@ fn is_connection_error(e: &io::Error) -> bool {
)
}
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
async fn accept<L>(listener: &mut L) -> Option<(L::Io, L::Addr)>
where
L: Listener,
{
match listener.accept().await {
Ok(conn) => Some(conn),
Err(e) => {
@ -444,6 +522,35 @@ async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
}
}
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
L: Listener,
{
io: &'a TokioIo<L::Io>,
remote_addr: L::Addr,
}
impl<L> IncomingStream<'_, L>
where
L: Listener,
{
/// Get a reference to the inner IO type.
pub fn io(&self) -> &L::Io {
self.io.inner()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> &L::Addr {
&self.remote_addr
}
}
mod private {
use std::{
future::Future,
@ -470,33 +577,15 @@ mod private {
}
}
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a> {
tcp_stream: &'a TokioIo<TcpStream>,
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_stream.inner().local_addr()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
#[cfg(test)]
mod tests {
use http::StatusCode;
use tokio::net::UnixListener;
use super::*;
use crate::{
body::to_bytes,
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt},
routing::get,
Router,
@ -508,30 +597,63 @@ mod tests {
#[allow(dead_code, unused_must_use)]
async fn if_it_compiles_it_works() {
#[derive(Clone, Debug)]
struct UdsConnectInfo;
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
Self
}
}
let router: Router = Router::new();
let addr = "0.0.0.0:0";
// router
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(UnixListener::bind("").unwrap(), router.clone());
serve(
TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
router.into_make_service_with_connect_info::<SocketAddr>(),
router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
);
// method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
);
serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<SocketAddr>(),
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
);
// handler
@ -539,17 +661,27 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
);
serve(UnixListener::bind("").unwrap(), handler.into_service());
serve(
TcpListener::bind(addr).await.unwrap(),
handler.with_state(()),
);
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(),
);
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(),
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
);
// nodelay
@ -593,4 +725,49 @@ mod tests {
assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
assert_ne!(address.port(), 0);
}
#[crate::test]
async fn serving_on_custom_io_type() {
struct ReadyListener<T>(Option<T>);
impl<T> Listener for ReadyListener<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Io = T;
type Addr = ();
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
match self.0.take() {
Some(server) => Ok((server, ())),
None => std::future::pending().await,
}
}
fn local_addr(&self) -> io::Result<Self::Addr> {
Ok(())
}
}
let (client, server) = tokio::io::duplex(1024);
let listener = ReadyListener(Some(server));
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
tokio::spawn(serve(listener, app).into_future());
let stream = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
tokio::spawn(conn);
let request = Request::builder().body(Body::empty()).unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = Body::new(response.into_body());
let body = to_bytes(body, usize::MAX).await.unwrap();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}
}

View file

@ -23,17 +23,13 @@ mod unix {
extract::connect_info::{self, ConnectInfo},
http::{Method, Request, StatusCode},
routing::get,
serve::IncomingStream,
Router,
};
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server,
};
use std::{convert::Infallible, path::PathBuf, sync::Arc};
use hyper_util::rt::TokioIo;
use std::{path::PathBuf, sync::Arc};
use tokio::net::{unix::UCred, UnixListener, UnixStream};
use tower::Service;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
pub async fn server() {
@ -54,33 +50,11 @@ mod unix {
let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async move {
let app = Router::new().route("/", get(handler));
let app = Router::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<UdsConnectInfo>();
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>();
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
// more details about this setup
loop {
let (socket, _remote_addr) = uds.accept().await.unwrap();
let tower_service = unwrap_infallible(make_service.call(&socket).await);
tokio::spawn(async move {
let socket = TokioIo::new(socket);
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(socket, hyper_service)
.await
{
eprintln!("failed to serve connection: {err:#}");
}
});
}
axum::serve(uds, app).await.unwrap();
});
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
@ -119,22 +93,14 @@ mod unix {
peer_cred: UCred,
}
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
fn connect_info(target: &UnixStream) -> Self {
let peer_addr = target.peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap();
impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
let peer_addr = stream.io().peer_addr().unwrap();
let peer_cred = stream.io().peer_cred().unwrap();
Self {
peer_addr: Arc::new(peer_addr),
peer_cred,
}
}
}
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
}