Make serve generic over the listener and IO types

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2024-09-27 23:57:26 +02:00
parent 689ca1aea2
commit 26a37a4679
No known key found for this signature in database
GPG key ID: 7D261D771D915378
9 changed files with 297 additions and 142 deletions

View file

@ -8,9 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931]) - **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
- **change:** Update minimum rust version to 1.75 ([#2943]) - **change:** Update minimum rust version to 1.75 ([#2943])
[#2931]: https://github.com/tokio-rs/axum/pull/2931 [#2931]: https://github.com/tokio-rs/axum/pull/2931
[#2941]: https://github.com/tokio-rs/axum/pull/2941
[#2943]: https://github.com/tokio-rs/axum/pull/2943 [#2943]: https://github.com/tokio-rs/axum/pull/2943
# 0.7.7 # 0.7.7

View file

@ -113,6 +113,7 @@ features = [
[dev-dependencies] [dev-dependencies]
anyhow = "1.0" anyhow = "1.0"
axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] } axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0" quickcheck = "1.0"
quickcheck_macros = "1.0" quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }

View file

@ -35,6 +35,7 @@ use axum::{
serve::IncomingStream, serve::IncomingStream,
Router, Router,
}; };
use tokio::net::TcpListener;
let app = Router::new().route("/", get(handler)); let app = Router::new().route("/", get(handler));
@ -49,8 +50,8 @@ struct MyConnectInfo {
// ... // ...
} }
impl Connected<IncomingStream<'_>> for MyConnectInfo { impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self { fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
MyConnectInfo { MyConnectInfo {
// ... // ...
} }

View file

@ -80,16 +80,17 @@ where
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static { pub trait Connected<T>: Clone + Send + Sync + 'static {
/// Create type holding information about the connection. /// Create type holding information about the connection.
fn connect_info(target: T) -> Self; fn connect_info(stream: T) -> Self;
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
use tokio::net::TcpListener;
impl Connected<IncomingStream<'_>> for SocketAddr { impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
fn connect_info(target: IncomingStream<'_>) -> Self { fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
target.remote_addr() *stream.remote_addr()
} }
} }
}; };
@ -263,8 +264,8 @@ mod tests {
value: &'static str, value: &'static str,
} }
impl Connected<IncomingStream<'_>> for MyConnectInfo { impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_>) -> Self { fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
Self { Self {
value: "it worked!", value: "it worked!",
} }

View file

@ -180,12 +180,13 @@ where
// for `axum::serve(listener, handler)` // for `axum::serve(listener, handler)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S> impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
where where
H: Clone, H: Clone,
S: Clone, S: Clone,
L: serve::Listener,
{ {
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
@ -195,7 +196,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone())) std::future::ready(Ok(self.clone()))
} }
} }

View file

@ -1227,9 +1227,12 @@ where
// for `axum::serve(listener, router)` // for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl Service<IncomingStream<'_>> for MethodRouter<()> { impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
where
L: serve::Listener,
{
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>; type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -1238,7 +1241,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone().with_state(()))) std::future::ready(Ok(self.clone().with_state(())))
} }
} }

View file

@ -486,9 +486,12 @@ impl Router {
// for `axum::serve(listener, router)` // for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl Service<IncomingStream<'_>> for Router<()> { impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
where
L: serve::Listener,
{
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>; type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -497,7 +500,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
// call `Router::with_state` such that everything is turned into `Route` eagerly // call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request // rather than doing that per request
std::future::ready(Ok(self.clone().with_state(()))) std::future::ready(Ok(self.clone().with_state(())))

View file

@ -1,12 +1,12 @@
//! Serve services. //! Serve services.
use std::{ use std::{
any::TypeId,
convert::Infallible, convert::Infallible,
fmt::Debug, fmt::Debug,
future::{poll_fn, Future, IntoFuture}, future::{poll_fn, Future, IntoFuture},
io, io,
marker::PhantomData, marker::PhantomData,
net::SocketAddr,
sync::Arc, sync::Arc,
time::Duration, time::Duration,
}; };
@ -18,12 +18,59 @@ use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))] #[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
sync::watch, sync::watch,
}; };
use tower::ServiceExt as _; use tower::ServiceExt as _;
use tower_service::Service; use tower_service::Service;
/// Types that can listen for connections.
pub trait Listener: Send + 'static {
/// The listener's IO type.
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// The listener's address type.
type Addr: Send;
/// Accept a new incoming connection to this listener
fn accept(&mut self) -> impl Future<Output = io::Result<(Self::Io, Self::Addr)>> + Send;
/// Returns the local address that this listener is bound to.
fn local_addr(&self) -> io::Result<Self::Addr>;
}
impl Listener for TcpListener {
type Io = TcpStream;
type Addr = std::net::SocketAddr;
#[inline]
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
Self::accept(self).await
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Io = tokio::net::UnixStream;
type Addr = tokio::net::unix::SocketAddr;
#[inline]
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
Self::accept(self).await
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
/// Serve the service with the supplied listener. /// Serve the service with the supplied listener.
/// ///
/// This method of running a service is intentionally simple and doesn't support any configuration. /// This method of running a service is intentionally simple and doesn't support any configuration.
@ -89,14 +136,15 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S> pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>, L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
{ {
Serve { Serve {
tcp_listener, listener,
make_service, make_service,
tcp_nodelay: None, tcp_nodelay: None,
_marker: PhantomData, _marker: PhantomData,
@ -106,15 +154,18 @@ where
/// Future returned by [`serve`]. /// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"] #[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> { pub struct Serve<L, M, S> {
tcp_listener: TcpListener, listener: L,
make_service: M, make_service: M,
tcp_nodelay: Option<bool>, tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> { impl<L, M, S> Serve<L, M, S>
where
L: Listener,
{
/// Prepares a server to handle graceful shutdown when the provided future completes. /// Prepares a server to handle graceful shutdown when the provided future completes.
/// ///
/// # Example /// # Example
@ -136,12 +187,12 @@ impl<M, S> Serve<M, S> {
/// // ... /// // ...
/// } /// }
/// ``` /// ```
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F> pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
WithGracefulShutdown { WithGracefulShutdown {
tcp_listener: self.tcp_listener, listener: self.listener,
make_service: self.make_service, make_service: self.make_service,
signal, signal,
tcp_nodelay: self.tcp_nodelay, tcp_nodelay: self.tcp_nodelay,
@ -149,6 +200,14 @@ impl<M, S> Serve<M, S> {
} }
} }
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<TcpListener, M, S> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
/// ///
/// See also [`TcpStream::set_nodelay`]. /// See also [`TcpStream::set_nodelay`].
@ -173,39 +232,41 @@ impl<M, S> Serve<M, S> {
..self ..self
} }
} }
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
}
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S> impl<L, M, S> Debug for Serve<L, M, S>
where where
L: Debug + 'static,
M: Debug, M: Debug,
{ {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { let Self {
tcp_listener, listener,
make_service, make_service,
tcp_nodelay, tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
f.debug_struct("Serve") let mut s = f.debug_struct("Serve");
.field("tcp_listener", tcp_listener) s.field("listener", listener)
.field("make_service", make_service) .field("make_service", make_service);
.field("tcp_nodelay", tcp_nodelay)
.finish() if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S> impl<L, M, S> IntoFuture for Serve<L, M, S>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static, L: Listener,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send, L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
{ {
@ -221,15 +282,27 @@ where
/// Serve future with graceful shutdown enabled. /// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"] #[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> { pub struct WithGracefulShutdown<L, M, S, F> {
tcp_listener: TcpListener, listener: L,
make_service: M, make_service: M,
signal: F, signal: F,
tcp_nodelay: Option<bool>, tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<M, S, F> WithGracefulShutdown<M, S, F> { #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
L: Listener,
{
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<L::Addr> {
self.listener.local_addr()
}
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> WithGracefulShutdown<TcpListener, M, S, F> {
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
/// ///
/// See also [`TcpStream::set_nodelay`]. /// See also [`TcpStream::set_nodelay`].
@ -259,43 +332,45 @@ impl<M, S, F> WithGracefulShutdown<M, S, F> {
..self ..self
} }
} }
/// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.tcp_listener.local_addr()
}
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F> impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where where
L: Debug + 'static,
M: Debug, M: Debug,
S: Debug, S: Debug,
F: Debug, F: Debug,
{ {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { let Self {
tcp_listener, listener,
make_service, make_service,
signal, signal,
tcp_nodelay, tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
f.debug_struct("WithGracefulShutdown") let mut s = f.debug_struct("WithGracefulShutdown");
.field("tcp_listener", tcp_listener) s.field("listener", listener)
.field("make_service", make_service) .field("make_service", make_service)
.field("signal", signal) .field("signal", signal);
.field("tcp_nodelay", tcp_nodelay)
.finish() if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F> impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static, L: Listener,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send, L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
@ -305,7 +380,7 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
let Self { let Self {
tcp_listener, mut listener,
mut make_service, mut make_service,
signal, signal,
tcp_nodelay, tcp_nodelay,
@ -324,8 +399,8 @@ where
private::ServeFuture(Box::pin(async move { private::ServeFuture(Box::pin(async move {
loop { loop {
let (tcp_stream, remote_addr) = tokio::select! { let (io, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => { conn = accept(&mut listener) => {
match conn { match conn {
Some(conn) => conn, Some(conn) => conn,
None => continue, None => continue,
@ -338,14 +413,16 @@ where
}; };
if let Some(nodelay) = tcp_nodelay { if let Some(nodelay) = tcp_nodelay {
let tcp_stream: &tokio::net::TcpStream = <dyn std::any::Any>::downcast_ref(&io)
.expect("internal error: tcp_nodelay used with the wrong type of listener");
if let Err(err) = tcp_stream.set_nodelay(nodelay) { if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
} }
} }
let tcp_stream = TokioIo::new(tcp_stream); let io = TokioIo::new(io);
trace!("connection {remote_addr} accepted"); trace!("connection {remote_addr:?} accepted");
poll_fn(|cx| make_service.poll_ready(cx)) poll_fn(|cx| make_service.poll_ready(cx))
.await .await
@ -353,7 +430,7 @@ where
let tower_service = make_service let tower_service = make_service
.call(IncomingStream { .call(IncomingStream {
tcp_stream: &tcp_stream, io: &io,
remote_addr, remote_addr,
}) })
.await .await
@ -368,7 +445,7 @@ where
tokio::spawn(async move { tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new()); let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn); pin_mut!(conn);
let signal_closed = signal_tx.closed().fuse(); let signal_closed = signal_tx.closed().fuse();
@ -389,14 +466,12 @@ where
} }
} }
trace!("connection {remote_addr} closed");
drop(close_rx); drop(close_rx);
}); });
} }
drop(close_rx); drop(close_rx);
drop(tcp_listener); drop(listener);
trace!( trace!(
"waiting for {} task(s) to finish", "waiting for {} task(s) to finish",
@ -418,7 +493,10 @@ fn is_connection_error(e: &io::Error) -> bool {
) )
} }
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { async fn accept<L>(listener: &mut L) -> Option<(L::Io, L::Addr)>
where
L: Listener,
{
match listener.accept().await { match listener.accept().await {
Ok(conn) => Some(conn), Ok(conn) => Some(conn),
Err(e) => { Err(e) => {
@ -444,6 +522,35 @@ async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
} }
} }
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
L: Listener,
{
io: &'a TokioIo<L::Io>,
remote_addr: L::Addr,
}
impl<L> IncomingStream<'_, L>
where
L: Listener,
{
/// Get a reference to the inner IO type.
pub fn io(&self) -> &L::Io {
self.io.inner()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> &L::Addr {
&self.remote_addr
}
}
mod private { mod private {
use std::{ use std::{
future::Future, future::Future,
@ -470,33 +577,15 @@ mod private {
} }
} }
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a> {
tcp_stream: &'a TokioIo<TcpStream>,
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_stream.inner().local_addr()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use http::StatusCode;
use tokio::net::UnixListener;
use super::*; use super::*;
use crate::{ use crate::{
body::to_bytes,
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt}, handler::{Handler, HandlerWithoutStateExt},
routing::get, routing::get,
Router, Router,
@ -508,30 +597,63 @@ mod tests {
#[allow(dead_code, unused_must_use)] #[allow(dead_code, unused_must_use)]
async fn if_it_compiles_it_works() { async fn if_it_compiles_it_works() {
#[derive(Clone, Debug)]
struct UdsConnectInfo;
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
Self
}
}
let router: Router = Router::new(); let router: Router = Router::new();
let addr = "0.0.0.0:0"; let addr = "0.0.0.0:0";
// router // router
serve(TcpListener::bind(addr).await.unwrap(), router.clone()); serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(UnixListener::bind("").unwrap(), router.clone());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(), router.clone().into_make_service(),
); );
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
);
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
router.into_make_service_with_connect_info::<SocketAddr>(), router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
); );
// method router // method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler)); serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(), get(handler).into_make_service(),
); );
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
);
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<SocketAddr>(), get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
); );
// handler // handler
@ -539,17 +661,27 @@ mod tests {
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_service(), handler.into_service(),
); );
serve(UnixListener::bind("").unwrap(), handler.into_service());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.with_state(()), handler.with_state(()),
); );
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(), handler.into_make_service(),
); );
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(), handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
); );
// nodelay // nodelay
@ -593,4 +725,49 @@ mod tests {
assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
assert_ne!(address.port(), 0); assert_ne!(address.port(), 0);
} }
#[crate::test]
async fn serving_on_custom_io_type() {
struct ReadyListener<T>(Option<T>);
impl<T> Listener for ReadyListener<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Io = T;
type Addr = ();
async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> {
match self.0.take() {
Some(server) => Ok((server, ())),
None => std::future::pending().await,
}
}
fn local_addr(&self) -> io::Result<Self::Addr> {
Ok(())
}
}
let (client, server) = tokio::io::duplex(1024);
let listener = ReadyListener(Some(server));
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
tokio::spawn(serve(listener, app).into_future());
let stream = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
tokio::spawn(conn);
let request = Request::builder().body(Body::empty()).unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = Body::new(response.into_body());
let body = to_bytes(body, usize::MAX).await.unwrap();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}
} }

View file

@ -23,17 +23,13 @@ mod unix {
extract::connect_info::{self, ConnectInfo}, extract::connect_info::{self, ConnectInfo},
http::{Method, Request, StatusCode}, http::{Method, Request, StatusCode},
routing::get, routing::get,
serve::IncomingStream,
Router, Router,
}; };
use http_body_util::BodyExt; use http_body_util::BodyExt;
use hyper::body::Incoming; use hyper_util::rt::TokioIo;
use hyper_util::{ use std::{path::PathBuf, sync::Arc};
rt::{TokioExecutor, TokioIo},
server,
};
use std::{convert::Infallible, path::PathBuf, sync::Arc};
use tokio::net::{unix::UCred, UnixListener, UnixStream}; use tokio::net::{unix::UCred, UnixListener, UnixStream};
use tower::Service;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
pub async fn server() { pub async fn server() {
@ -54,33 +50,11 @@ mod unix {
let uds = UnixListener::bind(path.clone()).unwrap(); let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async move { tokio::spawn(async move {
let app = Router::new().route("/", get(handler)); let app = Router::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<UdsConnectInfo>();
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>(); axum::serve(uds, app).await.unwrap();
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
// more details about this setup
loop {
let (socket, _remote_addr) = uds.accept().await.unwrap();
let tower_service = unwrap_infallible(make_service.call(&socket).await);
tokio::spawn(async move {
let socket = TokioIo::new(socket);
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(socket, hyper_service)
.await
{
eprintln!("failed to serve connection: {err:#}");
}
});
}
}); });
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
@ -119,22 +93,14 @@ mod unix {
peer_cred: UCred, peer_cred: UCred,
} }
impl connect_info::Connected<&UnixStream> for UdsConnectInfo { impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(target: &UnixStream) -> Self { fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
let peer_addr = target.peer_addr().unwrap(); let peer_addr = stream.io().peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap(); let peer_cred = stream.io().peer_cred().unwrap();
Self { Self {
peer_addr: Arc::new(peer_addr), peer_addr: Arc::new(peer_addr),
peer_cred, peer_cred,
} }
} }
} }
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
} }