Make serve generic over the listener and IO types (#2941)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2024-11-30 22:47:13 +00:00 committed by GitHub
parent d84136e1e4
commit 84c3960639
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 399 additions and 254 deletions

View file

@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`) - **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors, This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
instead of only deserialization error message in `ErrorKind::Message`. ([#2720]) instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
[#2897]: https://github.com/tokio-rs/axum/pull/2897 [#2897]: https://github.com/tokio-rs/axum/pull/2897
[#2903]: https://github.com/tokio-rs/axum/pull/2903 [#2903]: https://github.com/tokio-rs/axum/pull/2903
@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#2992]: https://github.com/tokio-rs/axum/pull/2992 [#2992]: https://github.com/tokio-rs/axum/pull/2992
[#2720]: https://github.com/tokio-rs/axum/pull/2720 [#2720]: https://github.com/tokio-rs/axum/pull/2720
[#3039]: https://github.com/tokio-rs/axum/pull/3039 [#3039]: https://github.com/tokio-rs/axum/pull/3039
[#2941]: https://github.com/tokio-rs/axum/pull/2941
# 0.8.0 # 0.8.0

View file

@ -35,6 +35,7 @@ use axum::{
serve::IncomingStream, serve::IncomingStream,
Router, Router,
}; };
use tokio::net::TcpListener;
let app = Router::new().route("/", get(handler)); let app = Router::new().route("/", get(handler));
@ -49,8 +50,8 @@ struct MyConnectInfo {
// ... // ...
} }
impl Connected<IncomingStream<'_>> for MyConnectInfo { impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self { fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
MyConnectInfo { MyConnectInfo {
// ... // ...
} }

View file

@ -79,16 +79,17 @@ where
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static { pub trait Connected<T>: Clone + Send + Sync + 'static {
/// Create type holding information about the connection. /// Create type holding information about the connection.
fn connect_info(target: T) -> Self; fn connect_info(stream: T) -> Self;
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
use tokio::net::TcpListener;
impl Connected<IncomingStream<'_>> for SocketAddr { impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
fn connect_info(target: IncomingStream<'_>) -> Self { fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
target.remote_addr() *stream.remote_addr()
} }
} }
}; };
@ -261,8 +262,8 @@ mod tests {
value: &'static str, value: &'static str,
} }
impl Connected<IncomingStream<'_>> for MyConnectInfo { impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
fn connect_info(_target: IncomingStream<'_>) -> Self { fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
Self { Self {
value: "it worked!", value: "it worked!",
} }

View file

@ -180,12 +180,13 @@ where
// for `axum::serve(listener, handler)` // for `axum::serve(listener, handler)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S> impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
where where
H: Clone, H: Clone,
S: Clone, S: Clone,
L: serve::Listener,
{ {
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
@ -195,7 +196,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone())) std::future::ready(Ok(self.clone()))
} }
} }

View file

@ -1313,9 +1313,12 @@ where
// for `axum::serve(listener, router)` // for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl Service<IncomingStream<'_>> for MethodRouter<()> { impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
where
L: serve::Listener,
{
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>; type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -1324,7 +1327,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
std::future::ready(Ok(self.clone().with_state(()))) std::future::ready(Ok(self.clone().with_state(())))
} }
} }

View file

@ -518,9 +518,12 @@ impl Router {
// for `axum::serve(listener, router)` // for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = { const _: () = {
use crate::serve::IncomingStream; use crate::serve;
impl Service<IncomingStream<'_>> for Router<()> { impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
where
L: serve::Listener,
{
type Response = Self; type Response = Self;
type Error = Infallible; type Error = Infallible;
type Future = std::future::Ready<Result<Self::Response, Self::Error>>; type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
@ -529,7 +532,7 @@ const _: () = {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
// call `Router::with_state` such that everything is turned into `Route` eagerly // call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request // rather than doing that per request
std::future::ready(Ok(self.clone().with_state(()))) std::future::ready(Ok(self.clone().with_state(())))

View file

@ -1,14 +1,13 @@
//! Serve services. //! Serve services.
use std::{ use std::{
any::TypeId,
convert::Infallible, convert::Infallible,
fmt::Debug, fmt::Debug,
future::{poll_fn, Future, IntoFuture}, future::{poll_fn, Future, IntoFuture},
io, io,
marker::PhantomData, marker::PhantomData,
net::SocketAddr,
sync::Arc, sync::Arc,
time::Duration,
}; };
use axum_core::{body::Body, extract::Request, response::Response}; use axum_core::{body::Body, extract::Request, response::Response};
@ -17,13 +16,14 @@ use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))] #[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{ use tokio::{net::TcpListener, sync::watch};
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::ServiceExt as _; use tower::ServiceExt as _;
use tower_service::Service; use tower_service::Service;
mod listener;
pub use self::listener::{Listener, ListenerExt, TapIo};
/// Serve the service with the supplied listener. /// Serve the service with the supplied listener.
/// ///
/// This method of running a service is intentionally simple and doesn't support any configuration. /// This method of running a service is intentionally simple and doesn't support any configuration.
@ -89,14 +89,15 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S> pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>, L: Listener,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
{ {
Serve { Serve {
tcp_listener, listener,
make_service, make_service,
tcp_nodelay: None, tcp_nodelay: None,
_marker: PhantomData, _marker: PhantomData,
@ -106,15 +107,18 @@ where
/// Future returned by [`serve`]. /// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"] #[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> { pub struct Serve<L, M, S> {
tcp_listener: TcpListener, listener: L,
make_service: M, make_service: M,
tcp_nodelay: Option<bool>, tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> { impl<L, M, S> Serve<L, M, S>
where
L: Listener,
{
/// Prepares a server to handle graceful shutdown when the provided future completes. /// Prepares a server to handle graceful shutdown when the provided future completes.
/// ///
/// # Example /// # Example
@ -136,76 +140,57 @@ impl<M, S> Serve<M, S> {
/// // ... /// // ...
/// } /// }
/// ``` /// ```
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F> pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
WithGracefulShutdown { WithGracefulShutdown {
tcp_listener: self.tcp_listener, listener: self.listener,
make_service: self.make_service, make_service: self.make_service,
signal, signal,
tcp_nodelay: self.tcp_nodelay,
_marker: PhantomData, _marker: PhantomData,
} }
} }
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
///
/// See also [`TcpStream::set_nodelay`].
///
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
/// Returns the local address this server is bound to. /// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<L::Addr> {
self.tcp_listener.local_addr() self.listener.local_addr()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S> impl<L, M, S> Debug for Serve<L, M, S>
where where
L: Debug + 'static,
M: Debug, M: Debug,
{ {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { let Self {
tcp_listener, listener,
make_service, make_service,
tcp_nodelay, tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
f.debug_struct("Serve") let mut s = f.debug_struct("Serve");
.field("tcp_listener", tcp_listener) s.field("listener", listener)
.field("make_service", make_service) .field("make_service", make_service);
.field("tcp_nodelay", tcp_nodelay)
.finish() if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
s.field("tcp_nodelay", tcp_nodelay);
}
s.finish()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S> impl<L, M, S> IntoFuture for Serve<L, M, S>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static, L: Listener,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send, L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
{ {
@ -221,81 +206,55 @@ where
/// Serve future with graceful shutdown enabled. /// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"] #[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> { pub struct WithGracefulShutdown<L, M, S, F> {
tcp_listener: TcpListener, listener: L,
make_service: M, make_service: M,
signal: F, signal: F,
tcp_nodelay: Option<bool>,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<M, S, F> WithGracefulShutdown<M, S, F> { #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
/// where
/// See also [`TcpStream::set_nodelay`]. L: Listener,
/// {
/// # Example
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .with_graceful_shutdown(shutdown_signal())
/// .tcp_nodelay(true)
/// .await
/// .unwrap();
/// # };
///
/// async fn shutdown_signal() {
/// // ...
/// }
/// ```
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
Self {
tcp_nodelay: Some(nodelay),
..self
}
}
/// Returns the local address this server is bound to. /// Returns the local address this server is bound to.
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<L::Addr> {
self.tcp_listener.local_addr() self.listener.local_addr()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F> impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where where
L: Debug + 'static,
M: Debug, M: Debug,
S: Debug, S: Debug,
F: Debug, F: Debug,
{ {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { let Self {
tcp_listener, listener,
make_service, make_service,
signal, signal,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
f.debug_struct("WithGracefulShutdown") f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener) .field("listener", listener)
.field("make_service", make_service) .field("make_service", make_service)
.field("signal", signal) .field("signal", signal)
.field("tcp_nodelay", tcp_nodelay)
.finish() .finish()
} }
} }
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F> impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static, L: Listener,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send, L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static, S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send, S::Future: Send,
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
@ -305,10 +264,9 @@ where
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
let Self { let Self {
tcp_listener, mut listener,
mut make_service, mut make_service,
signal, signal,
tcp_nodelay,
_marker: _, _marker: _,
} = self; } = self;
@ -324,28 +282,17 @@ where
let (close_tx, close_rx) = watch::channel(()); let (close_tx, close_rx) = watch::channel(());
loop { loop {
let (tcp_stream, remote_addr) = tokio::select! { let (io, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => { conn = listener.accept() => conn,
match conn {
Some(conn) => conn,
None => continue,
}
}
_ = signal_tx.closed() => { _ = signal_tx.closed() => {
trace!("signal received, not accepting new connections"); trace!("signal received, not accepting new connections");
break; break;
} }
}; };
if let Some(nodelay) = tcp_nodelay { let io = TokioIo::new(io);
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let tcp_stream = TokioIo::new(tcp_stream); trace!("connection {remote_addr:?} accepted");
trace!("connection {remote_addr} accepted");
poll_fn(|cx| make_service.poll_ready(cx)) poll_fn(|cx| make_service.poll_ready(cx))
.await .await
@ -353,7 +300,7 @@ where
let tower_service = make_service let tower_service = make_service
.call(IncomingStream { .call(IncomingStream {
tcp_stream: &tcp_stream, io: &io,
remote_addr, remote_addr,
}) })
.await .await
@ -372,7 +319,7 @@ where
// CONNECT protocol needed for HTTP/2 websockets // CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")] #[cfg(feature = "http2")]
builder.http2().enable_connect_protocol(); builder.http2().enable_connect_protocol();
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn); pin_mut!(conn);
let signal_closed = signal_tx.closed().fuse(); let signal_closed = signal_tx.closed().fuse();
@ -393,14 +340,12 @@ where
} }
} }
trace!("connection {remote_addr} closed");
drop(close_rx); drop(close_rx);
}); });
} }
drop(close_rx); drop(close_rx);
drop(tcp_listener); drop(listener);
trace!( trace!(
"waiting for {} task(s) to finish", "waiting for {} task(s) to finish",
@ -413,38 +358,32 @@ where
} }
} }
fn is_connection_error(e: &io::Error) -> bool { /// An incoming stream.
matches!( ///
e.kind(), /// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
io::ErrorKind::ConnectionRefused ///
| io::ErrorKind::ConnectionAborted /// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
| io::ErrorKind::ConnectionReset #[derive(Debug)]
) pub struct IncomingStream<'a, L>
where
L: Listener,
{
io: &'a TokioIo<L::Io>,
remote_addr: L::Addr,
} }
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { impl<L> IncomingStream<'_, L>
match listener.accept().await { where
Ok(conn) => Some(conn), L: Listener,
Err(e) => { {
if is_connection_error(&e) { /// Get a reference to the inner IO type.
return None; pub fn io(&self) -> &L::Io {
self.io.inner()
} }
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186) /// Returns the remote address that this stream is bound to.
// pub fn remote_addr(&self) -> &L::Addr {
// > A possible scenario is that the process has hit the max open files &self.remote_addr
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
None
}
} }
} }
@ -474,68 +413,89 @@ mod private {
} }
} }
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
///
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a> {
tcp_stream: &'a TokioIo<TcpStream>,
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_stream.inner().local_addr()
}
/// Returns the remote address that this stream is bound to.
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use std::{
future::{pending, IntoFuture as _},
net::{IpAddr, Ipv4Addr},
};
use axum_core::{body::Body, extract::Request};
use http::StatusCode;
use hyper_util::rt::TokioIo;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, UnixListener},
};
use super::{serve, IncomingStream, Listener};
use crate::{ use crate::{
body::to_bytes,
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt}, handler::{Handler, HandlerWithoutStateExt},
routing::get, routing::get,
Router, Router,
}; };
use std::{
future::pending,
net::{IpAddr, Ipv4Addr},
};
#[allow(dead_code, unused_must_use)] #[allow(dead_code, unused_must_use)]
async fn if_it_compiles_it_works() { async fn if_it_compiles_it_works() {
#[derive(Clone, Debug)]
struct UdsConnectInfo;
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
Self
}
}
let router: Router = Router::new(); let router: Router = Router::new();
let addr = "0.0.0.0:0"; let addr = "0.0.0.0:0";
// router // router
serve(TcpListener::bind(addr).await.unwrap(), router.clone()); serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(UnixListener::bind("").unwrap(), router.clone());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(), router.clone().into_make_service(),
); );
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
);
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
router.into_make_service_with_connect_info::<SocketAddr>(), router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
); );
// method router // method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler)); serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(), get(handler).into_make_service(),
); );
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
);
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<SocketAddr>(), get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
); );
// handler // handler
@ -543,32 +503,28 @@ mod tests {
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_service(), handler.into_service(),
); );
serve(UnixListener::bind("").unwrap(), handler.into_service());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.with_state(()), handler.with_state(()),
); );
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(), handler.into_make_service(),
); );
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
serve( serve(
TcpListener::bind(addr).await.unwrap(), TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<SocketAddr>(), handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
); );
// nodelay
serve( serve(
TcpListener::bind(addr).await.unwrap(), UnixListener::bind("").unwrap(),
handler.into_service(), handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
) );
.tcp_nodelay(true);
serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
)
.with_graceful_shutdown(async { /*...*/ })
.tcp_nodelay(true);
} }
async fn handler() {} async fn handler() {}
@ -613,4 +569,49 @@ mod tests {
// Call Serve::into_future outside of a tokio context. This used to panic. // Call Serve::into_future outside of a tokio context. This used to panic.
_ = serve(listener, router).into_future(); _ = serve(listener, router).into_future();
} }
#[crate::test]
async fn serving_on_custom_io_type() {
struct ReadyListener<T>(Option<T>);
impl<T> Listener for ReadyListener<T>
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Io = T;
type Addr = ();
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
match self.0.take() {
Some(server) => (server, ()),
None => std::future::pending().await,
}
}
fn local_addr(&self) -> io::Result<Self::Addr> {
Ok(())
}
}
let (client, server) = io::duplex(1024);
let listener = ReadyListener(Some(server));
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
tokio::spawn(serve(listener, app).into_future());
let stream = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
tokio::spawn(conn);
let request = Request::builder().body(Body::empty()).unwrap();
let response = sender.send_request(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = Body::new(response.into_body());
let body = to_bytes(body, usize::MAX).await.unwrap();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}
} }

167
axum/src/serve/listener.rs Normal file
View file

@ -0,0 +1,167 @@
use std::{fmt, future::Future, time::Duration};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};
/// Types that can listen for connections.
pub trait Listener: Send + 'static {
/// The listener's IO type.
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
/// The listener's address type.
type Addr: Send;
/// Accept a new incoming connection to this listener.
///
/// If the underlying accept call can return an error, this function must
/// take care of logging and retrying.
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
/// Returns the local address that this listener is bound to.
fn local_addr(&self) -> io::Result<Self::Addr>;
}
impl Listener for TcpListener {
type Io = TcpStream;
type Addr = std::net::SocketAddr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
loop {
match Self::accept(self).await {
Ok(tup) => return tup,
Err(e) => handle_accept_error(e).await,
}
}
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
type Io = tokio::net::UnixStream;
type Addr = tokio::net::unix::SocketAddr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
loop {
match Self::accept(self).await {
Ok(tup) => return tup,
Err(e) => handle_accept_error(e).await,
}
}
}
#[inline]
fn local_addr(&self) -> io::Result<Self::Addr> {
Self::local_addr(self)
}
}
/// Extensions to [`Listener`].
pub trait ListenerExt: Listener + Sized {
/// Run a mutable closure on every accepted `Io`.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get, serve::ListenerExt};
/// use tracing::trace;
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
/// .await
/// .unwrap()
/// .tap_io(|tcp_stream| {
/// if let Err(err) = tcp_stream.set_nodelay(true) {
/// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
/// }
/// });
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
where
F: FnMut(&mut Self::Io) + Send + 'static,
{
TapIo {
listener: self,
tap_fn,
}
}
}
impl<L: Listener> ListenerExt for L {}
/// Return type of [`ListenerExt::tap_io`].
///
/// See that method for details.
pub struct TapIo<L: Listener, F> {
listener: L,
tap_fn: F,
}
impl<L, F> fmt::Debug for TapIo<L, F>
where
L: Listener + fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TapIo")
.field("listener", &self.listener)
.finish_non_exhaustive()
}
}
impl<L, F> Listener for TapIo<L, F>
where
L: Listener,
F: FnMut(&mut L::Io) + Send + 'static,
{
type Io = L::Io;
type Addr = L::Addr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
let (mut io, addr) = self.listener.accept().await;
(self.tap_fn)(&mut io);
(io, addr)
}
fn local_addr(&self) -> io::Result<Self::Addr> {
self.listener.local_addr()
}
}
async fn handle_accept_error(e: io::Error) {
if is_connection_error(&e) {
return;
}
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
fn is_connection_error(e: &io::Error) -> bool {
matches!(
e.kind(),
io::ErrorKind::ConnectionRefused
| io::ErrorKind::ConnectionAborted
| io::ErrorKind::ConnectionReset
)
}

View file

@ -21,17 +21,13 @@ mod unix {
extract::connect_info::{self, ConnectInfo}, extract::connect_info::{self, ConnectInfo},
http::{Method, Request, StatusCode}, http::{Method, Request, StatusCode},
routing::get, routing::get,
serve::IncomingStream,
Router, Router,
}; };
use http_body_util::BodyExt; use http_body_util::BodyExt;
use hyper::body::Incoming; use hyper_util::rt::TokioIo;
use hyper_util::{ use std::{path::PathBuf, sync::Arc};
rt::{TokioExecutor, TokioIo},
server,
};
use std::{convert::Infallible, path::PathBuf, sync::Arc};
use tokio::net::{unix::UCred, UnixListener, UnixStream}; use tokio::net::{unix::UCred, UnixListener, UnixStream};
use tower::Service;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
pub async fn server() { pub async fn server() {
@ -52,33 +48,11 @@ mod unix {
let uds = UnixListener::bind(path.clone()).unwrap(); let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async move { tokio::spawn(async move {
let app = Router::new().route("/", get(handler)); let app = Router::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<UdsConnectInfo>();
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>(); axum::serve(uds, app).await.unwrap();
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
// more details about this setup
loop {
let (socket, _remote_addr) = uds.accept().await.unwrap();
let tower_service = unwrap_infallible(make_service.call(&socket).await);
tokio::spawn(async move {
let socket = TokioIo::new(socket);
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
tower_service.clone().call(request)
});
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(socket, hyper_service)
.await
{
eprintln!("failed to serve connection: {err:#}");
}
});
}
}); });
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
@ -117,22 +91,14 @@ mod unix {
peer_cred: UCred, peer_cred: UCred,
} }
impl connect_info::Connected<&UnixStream> for UdsConnectInfo { impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
fn connect_info(target: &UnixStream) -> Self { fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
let peer_addr = target.peer_addr().unwrap(); let peer_addr = stream.io().peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap(); let peer_cred = stream.io().peer_cred().unwrap();
Self { Self {
peer_addr: Arc::new(peer_addr), peer_addr: Arc::new(peer_addr),
peer_cred, peer_cred,
} }
} }
} }
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
match result {
Ok(value) => value,
Err(err) => match err {},
}
}
} }