mirror of
https://github.com/tokio-rs/axum.git
synced 2024-12-04 05:54:40 +01:00
Make serve
generic over the listener and IO types (#2941)
Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
parent
d84136e1e4
commit
84c3960639
9 changed files with 399 additions and 254 deletions
|
@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
|
||||
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
|
||||
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
|
||||
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
|
||||
|
||||
[#2897]: https://github.com/tokio-rs/axum/pull/2897
|
||||
[#2903]: https://github.com/tokio-rs/axum/pull/2903
|
||||
|
@ -34,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
[#2992]: https://github.com/tokio-rs/axum/pull/2992
|
||||
[#2720]: https://github.com/tokio-rs/axum/pull/2720
|
||||
[#3039]: https://github.com/tokio-rs/axum/pull/3039
|
||||
[#2941]: https://github.com/tokio-rs/axum/pull/2941
|
||||
|
||||
# 0.8.0
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ use axum::{
|
|||
serve::IncomingStream,
|
||||
Router,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
let app = Router::new().route("/", get(handler));
|
||||
|
||||
|
@ -49,8 +50,8 @@ struct MyConnectInfo {
|
|||
// ...
|
||||
}
|
||||
|
||||
impl Connected<IncomingStream<'_>> for MyConnectInfo {
|
||||
fn connect_info(target: IncomingStream<'_>) -> Self {
|
||||
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
|
||||
fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self {
|
||||
MyConnectInfo {
|
||||
// ...
|
||||
}
|
||||
|
|
|
@ -79,16 +79,17 @@ where
|
|||
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
|
||||
pub trait Connected<T>: Clone + Send + Sync + 'static {
|
||||
/// Create type holding information about the connection.
|
||||
fn connect_info(target: T) -> Self;
|
||||
fn connect_info(stream: T) -> Self;
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
const _: () = {
|
||||
use crate::serve::IncomingStream;
|
||||
use crate::serve;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
impl Connected<IncomingStream<'_>> for SocketAddr {
|
||||
fn connect_info(target: IncomingStream<'_>) -> Self {
|
||||
target.remote_addr()
|
||||
impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
|
||||
fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self {
|
||||
*stream.remote_addr()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -261,8 +262,8 @@ mod tests {
|
|||
value: &'static str,
|
||||
}
|
||||
|
||||
impl Connected<IncomingStream<'_>> for MyConnectInfo {
|
||||
fn connect_info(_target: IncomingStream<'_>) -> Self {
|
||||
impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
|
||||
fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
|
||||
Self {
|
||||
value: "it worked!",
|
||||
}
|
||||
|
|
|
@ -180,12 +180,13 @@ where
|
|||
// for `axum::serve(listener, handler)`
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
const _: () = {
|
||||
use crate::serve::IncomingStream;
|
||||
use crate::serve;
|
||||
|
||||
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
|
||||
impl<H, T, S, L> Service<serve::IncomingStream<'_, L>> for HandlerService<H, T, S>
|
||||
where
|
||||
H: Clone,
|
||||
S: Clone,
|
||||
L: serve::Listener,
|
||||
{
|
||||
type Response = Self;
|
||||
type Error = Infallible;
|
||||
|
@ -195,7 +196,7 @@ const _: () = {
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
||||
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||
std::future::ready(Ok(self.clone()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1313,9 +1313,12 @@ where
|
|||
// for `axum::serve(listener, router)`
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
const _: () = {
|
||||
use crate::serve::IncomingStream;
|
||||
use crate::serve;
|
||||
|
||||
impl Service<IncomingStream<'_>> for MethodRouter<()> {
|
||||
impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
|
||||
where
|
||||
L: serve::Listener,
|
||||
{
|
||||
type Response = Self;
|
||||
type Error = Infallible;
|
||||
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
||||
|
@ -1324,7 +1327,7 @@ const _: () = {
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
||||
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||
std::future::ready(Ok(self.clone().with_state(())))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -518,9 +518,12 @@ impl Router {
|
|||
// for `axum::serve(listener, router)`
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
const _: () = {
|
||||
use crate::serve::IncomingStream;
|
||||
use crate::serve;
|
||||
|
||||
impl Service<IncomingStream<'_>> for Router<()> {
|
||||
impl<L> Service<serve::IncomingStream<'_, L>> for Router<()>
|
||||
where
|
||||
L: serve::Listener,
|
||||
{
|
||||
type Response = Self;
|
||||
type Error = Infallible;
|
||||
type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
|
||||
|
@ -529,7 +532,7 @@ const _: () = {
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future {
|
||||
fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
|
||||
// call `Router::with_state` such that everything is turned into `Route` eagerly
|
||||
// rather than doing that per request
|
||||
std::future::ready(Ok(self.clone().with_state(())))
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
//! Serve services.
|
||||
|
||||
use std::{
|
||||
any::TypeId,
|
||||
convert::Infallible,
|
||||
fmt::Debug,
|
||||
future::{poll_fn, Future, IntoFuture},
|
||||
io,
|
||||
marker::PhantomData,
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use axum_core::{body::Body, extract::Request, response::Response};
|
||||
|
@ -17,13 +16,14 @@ use hyper::body::Incoming;
|
|||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
#[cfg(any(feature = "http1", feature = "http2"))]
|
||||
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::watch,
|
||||
};
|
||||
use tokio::{net::TcpListener, sync::watch};
|
||||
use tower::ServiceExt as _;
|
||||
use tower_service::Service;
|
||||
|
||||
mod listener;
|
||||
|
||||
pub use self::listener::{Listener, ListenerExt, TapIo};
|
||||
|
||||
/// Serve the service with the supplied listener.
|
||||
///
|
||||
/// This method of running a service is intentionally simple and doesn't support any configuration.
|
||||
|
@ -89,14 +89,15 @@ use tower_service::Service;
|
|||
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
|
||||
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
|
||||
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
|
||||
where
|
||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
|
||||
L: Listener,
|
||||
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
|
||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
Serve {
|
||||
tcp_listener,
|
||||
listener,
|
||||
make_service,
|
||||
tcp_nodelay: None,
|
||||
_marker: PhantomData,
|
||||
|
@ -106,15 +107,18 @@ where
|
|||
/// Future returned by [`serve`].
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
#[must_use = "futures must be awaited or polled"]
|
||||
pub struct Serve<M, S> {
|
||||
tcp_listener: TcpListener,
|
||||
pub struct Serve<L, M, S> {
|
||||
listener: L,
|
||||
make_service: M,
|
||||
tcp_nodelay: Option<bool>,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<M, S> Serve<M, S> {
|
||||
impl<L, M, S> Serve<L, M, S>
|
||||
where
|
||||
L: Listener,
|
||||
{
|
||||
/// Prepares a server to handle graceful shutdown when the provided future completes.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -136,76 +140,57 @@ impl<M, S> Serve<M, S> {
|
|||
/// // ...
|
||||
/// }
|
||||
/// ```
|
||||
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
|
||||
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
WithGracefulShutdown {
|
||||
tcp_listener: self.tcp_listener,
|
||||
listener: self.listener,
|
||||
make_service: self.make_service,
|
||||
signal,
|
||||
tcp_nodelay: self.tcp_nodelay,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
|
||||
///
|
||||
/// See also [`TcpStream::set_nodelay`].
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use axum::{Router, routing::get};
|
||||
///
|
||||
/// # async {
|
||||
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||
///
|
||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
||||
/// axum::serve(listener, router)
|
||||
/// .tcp_nodelay(true)
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
|
||||
Self {
|
||||
tcp_nodelay: Some(nodelay),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the local address this server is bound to.
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.tcp_listener.local_addr()
|
||||
pub fn local_addr(&self) -> io::Result<L::Addr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<M, S> Debug for Serve<M, S>
|
||||
impl<L, M, S> Debug for Serve<L, M, S>
|
||||
where
|
||||
L: Debug + 'static,
|
||||
M: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let Self {
|
||||
tcp_listener,
|
||||
listener,
|
||||
make_service,
|
||||
tcp_nodelay,
|
||||
_marker: _,
|
||||
} = self;
|
||||
|
||||
f.debug_struct("Serve")
|
||||
.field("tcp_listener", tcp_listener)
|
||||
.field("make_service", make_service)
|
||||
.field("tcp_nodelay", tcp_nodelay)
|
||||
.finish()
|
||||
let mut s = f.debug_struct("Serve");
|
||||
s.field("listener", listener)
|
||||
.field("make_service", make_service);
|
||||
|
||||
if TypeId::of::<L>() == TypeId::of::<TcpListener>() {
|
||||
s.field("tcp_nodelay", tcp_nodelay);
|
||||
}
|
||||
|
||||
s.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<M, S> IntoFuture for Serve<M, S>
|
||||
impl<L, M, S> IntoFuture for Serve<L, M, S>
|
||||
where
|
||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
|
||||
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
|
||||
L: Listener,
|
||||
L::Addr: Debug,
|
||||
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
|
||||
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
|
||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
{
|
||||
|
@ -221,81 +206,55 @@ where
|
|||
/// Serve future with graceful shutdown enabled.
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
#[must_use = "futures must be awaited or polled"]
|
||||
pub struct WithGracefulShutdown<M, S, F> {
|
||||
tcp_listener: TcpListener,
|
||||
pub struct WithGracefulShutdown<L, M, S, F> {
|
||||
listener: L,
|
||||
make_service: M,
|
||||
signal: F,
|
||||
tcp_nodelay: Option<bool>,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
|
||||
impl<M, S, F> WithGracefulShutdown<M, S, F> {
|
||||
/// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection.
|
||||
///
|
||||
/// See also [`TcpStream::set_nodelay`].
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use axum::{Router, routing::get};
|
||||
///
|
||||
/// # async {
|
||||
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||
///
|
||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
||||
/// axum::serve(listener, router)
|
||||
/// .with_graceful_shutdown(shutdown_signal())
|
||||
/// .tcp_nodelay(true)
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// # };
|
||||
///
|
||||
/// async fn shutdown_signal() {
|
||||
/// // ...
|
||||
/// }
|
||||
/// ```
|
||||
pub fn tcp_nodelay(self, nodelay: bool) -> Self {
|
||||
Self {
|
||||
tcp_nodelay: Some(nodelay),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
|
||||
where
|
||||
L: Listener,
|
||||
{
|
||||
/// Returns the local address this server is bound to.
|
||||
pub fn local_addr(&self) -> io::Result<SocketAddr> {
|
||||
self.tcp_listener.local_addr()
|
||||
pub fn local_addr(&self) -> io::Result<L::Addr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
|
||||
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
|
||||
where
|
||||
L: Debug + 'static,
|
||||
M: Debug,
|
||||
S: Debug,
|
||||
F: Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let Self {
|
||||
tcp_listener,
|
||||
listener,
|
||||
make_service,
|
||||
signal,
|
||||
tcp_nodelay,
|
||||
_marker: _,
|
||||
} = self;
|
||||
|
||||
f.debug_struct("WithGracefulShutdown")
|
||||
.field("tcp_listener", tcp_listener)
|
||||
.field("listener", listener)
|
||||
.field("make_service", make_service)
|
||||
.field("signal", signal)
|
||||
.field("tcp_nodelay", tcp_nodelay)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
|
||||
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
|
||||
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
|
||||
where
|
||||
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
|
||||
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
|
||||
L: Listener,
|
||||
L::Addr: Debug,
|
||||
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
|
||||
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
|
||||
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
|
||||
S::Future: Send,
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
|
@ -305,10 +264,9 @@ where
|
|||
|
||||
fn into_future(self) -> Self::IntoFuture {
|
||||
let Self {
|
||||
tcp_listener,
|
||||
mut listener,
|
||||
mut make_service,
|
||||
signal,
|
||||
tcp_nodelay,
|
||||
_marker: _,
|
||||
} = self;
|
||||
|
||||
|
@ -324,28 +282,17 @@ where
|
|||
let (close_tx, close_rx) = watch::channel(());
|
||||
|
||||
loop {
|
||||
let (tcp_stream, remote_addr) = tokio::select! {
|
||||
conn = tcp_accept(&tcp_listener) => {
|
||||
match conn {
|
||||
Some(conn) => conn,
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
let (io, remote_addr) = tokio::select! {
|
||||
conn = listener.accept() => conn,
|
||||
_ = signal_tx.closed() => {
|
||||
trace!("signal received, not accepting new connections");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(nodelay) = tcp_nodelay {
|
||||
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
|
||||
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
||||
}
|
||||
}
|
||||
let io = TokioIo::new(io);
|
||||
|
||||
let tcp_stream = TokioIo::new(tcp_stream);
|
||||
|
||||
trace!("connection {remote_addr} accepted");
|
||||
trace!("connection {remote_addr:?} accepted");
|
||||
|
||||
poll_fn(|cx| make_service.poll_ready(cx))
|
||||
.await
|
||||
|
@ -353,7 +300,7 @@ where
|
|||
|
||||
let tower_service = make_service
|
||||
.call(IncomingStream {
|
||||
tcp_stream: &tcp_stream,
|
||||
io: &io,
|
||||
remote_addr,
|
||||
})
|
||||
.await
|
||||
|
@ -372,7 +319,7 @@ where
|
|||
// CONNECT protocol needed for HTTP/2 websockets
|
||||
#[cfg(feature = "http2")]
|
||||
builder.http2().enable_connect_protocol();
|
||||
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
|
||||
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
|
||||
pin_mut!(conn);
|
||||
|
||||
let signal_closed = signal_tx.closed().fuse();
|
||||
|
@ -393,14 +340,12 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
trace!("connection {remote_addr} closed");
|
||||
|
||||
drop(close_rx);
|
||||
});
|
||||
}
|
||||
|
||||
drop(close_rx);
|
||||
drop(tcp_listener);
|
||||
drop(listener);
|
||||
|
||||
trace!(
|
||||
"waiting for {} task(s) to finish",
|
||||
|
@ -413,38 +358,32 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn is_connection_error(e: &io::Error) -> bool {
|
||||
matches!(
|
||||
e.kind(),
|
||||
io::ErrorKind::ConnectionRefused
|
||||
| io::ErrorKind::ConnectionAborted
|
||||
| io::ErrorKind::ConnectionReset
|
||||
)
|
||||
/// An incoming stream.
|
||||
///
|
||||
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
|
||||
///
|
||||
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
|
||||
#[derive(Debug)]
|
||||
pub struct IncomingStream<'a, L>
|
||||
where
|
||||
L: Listener,
|
||||
{
|
||||
io: &'a TokioIo<L::Io>,
|
||||
remote_addr: L::Addr,
|
||||
}
|
||||
|
||||
async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
|
||||
match listener.accept().await {
|
||||
Ok(conn) => Some(conn),
|
||||
Err(e) => {
|
||||
if is_connection_error(&e) {
|
||||
return None;
|
||||
}
|
||||
impl<L> IncomingStream<'_, L>
|
||||
where
|
||||
L: Listener,
|
||||
{
|
||||
/// Get a reference to the inner IO type.
|
||||
pub fn io(&self) -> &L::Io {
|
||||
self.io.inner()
|
||||
}
|
||||
|
||||
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
|
||||
//
|
||||
// > A possible scenario is that the process has hit the max open files
|
||||
// > allowed, and so trying to accept a new connection will fail with
|
||||
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
||||
// > the application will likely close some files (or connections), and try
|
||||
// > to accept the connection again. If this option is `true`, the error
|
||||
// > will be logged at the `error` level, since it is still a big deal,
|
||||
// > and then the listener will sleep for 1 second.
|
||||
//
|
||||
// hyper allowed customizing this but axum does not.
|
||||
error!("accept error: {e}");
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
None
|
||||
}
|
||||
/// Returns the remote address that this stream is bound to.
|
||||
pub fn remote_addr(&self) -> &L::Addr {
|
||||
&self.remote_addr
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -474,68 +413,89 @@ mod private {
|
|||
}
|
||||
}
|
||||
|
||||
/// An incoming stream.
|
||||
///
|
||||
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
|
||||
///
|
||||
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
|
||||
#[derive(Debug)]
|
||||
pub struct IncomingStream<'a> {
|
||||
tcp_stream: &'a TokioIo<TcpStream>,
|
||||
remote_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl IncomingStream<'_> {
|
||||
/// Returns the local address that this stream is bound to.
|
||||
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
|
||||
self.tcp_stream.inner().local_addr()
|
||||
}
|
||||
|
||||
/// Returns the remote address that this stream is bound to.
|
||||
pub fn remote_addr(&self) -> SocketAddr {
|
||||
self.remote_addr
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::{
|
||||
future::{pending, IntoFuture as _},
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
};
|
||||
|
||||
use axum_core::{body::Body, extract::Request};
|
||||
use http::StatusCode;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, UnixListener},
|
||||
};
|
||||
|
||||
use super::{serve, IncomingStream, Listener};
|
||||
use crate::{
|
||||
body::to_bytes,
|
||||
extract::connect_info::Connected,
|
||||
handler::{Handler, HandlerWithoutStateExt},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use std::{
|
||||
future::pending,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
};
|
||||
|
||||
#[allow(dead_code, unused_must_use)]
|
||||
async fn if_it_compiles_it_works() {
|
||||
#[derive(Clone, Debug)]
|
||||
struct UdsConnectInfo;
|
||||
|
||||
impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
|
||||
fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
let router: Router = Router::new();
|
||||
|
||||
let addr = "0.0.0.0:0";
|
||||
|
||||
// router
|
||||
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
|
||||
serve(UnixListener::bind("").unwrap(), router.clone());
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
router.clone().into_make_service(),
|
||||
);
|
||||
serve(
|
||||
UnixListener::bind("").unwrap(),
|
||||
router.clone().into_make_service(),
|
||||
);
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
router
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
);
|
||||
serve(
|
||||
UnixListener::bind("").unwrap(),
|
||||
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||
);
|
||||
|
||||
// method router
|
||||
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
|
||||
serve(UnixListener::bind("").unwrap(), get(handler));
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
get(handler).into_make_service(),
|
||||
);
|
||||
serve(
|
||||
UnixListener::bind("").unwrap(),
|
||||
get(handler).into_make_service(),
|
||||
);
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
get(handler).into_make_service_with_connect_info::<SocketAddr>(),
|
||||
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
);
|
||||
serve(
|
||||
UnixListener::bind("").unwrap(),
|
||||
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||
);
|
||||
|
||||
// handler
|
||||
|
@ -543,32 +503,28 @@ mod tests {
|
|||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.into_service(),
|
||||
);
|
||||
serve(UnixListener::bind("").unwrap(), handler.into_service());
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.with_state(()),
|
||||
);
|
||||
serve(UnixListener::bind("").unwrap(), handler.with_state(()));
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.into_make_service(),
|
||||
);
|
||||
serve(UnixListener::bind("").unwrap(), handler.into_make_service());
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
|
||||
);
|
||||
|
||||
// nodelay
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.into_service(),
|
||||
)
|
||||
.tcp_nodelay(true);
|
||||
|
||||
serve(
|
||||
TcpListener::bind(addr).await.unwrap(),
|
||||
handler.into_service(),
|
||||
)
|
||||
.with_graceful_shutdown(async { /*...*/ })
|
||||
.tcp_nodelay(true);
|
||||
UnixListener::bind("").unwrap(),
|
||||
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
|
||||
);
|
||||
}
|
||||
|
||||
async fn handler() {}
|
||||
|
@ -613,4 +569,49 @@ mod tests {
|
|||
// Call Serve::into_future outside of a tokio context. This used to panic.
|
||||
_ = serve(listener, router).into_future();
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
async fn serving_on_custom_io_type() {
|
||||
struct ReadyListener<T>(Option<T>);
|
||||
|
||||
impl<T> Listener for ReadyListener<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
type Io = T;
|
||||
type Addr = ();
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
match self.0.take() {
|
||||
Some(server) => (server, ()),
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
}
|
||||
|
||||
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (client, server) = io::duplex(1024);
|
||||
let listener = ReadyListener(Some(server));
|
||||
|
||||
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||
|
||||
tokio::spawn(serve(listener, app).into_future());
|
||||
|
||||
let stream = TokioIo::new(client);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
|
||||
tokio::spawn(conn);
|
||||
|
||||
let request = Request::builder().body(Body::empty()).unwrap();
|
||||
|
||||
let response = sender.send_request(request).await.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = Body::new(response.into_body());
|
||||
let body = to_bytes(body, usize::MAX).await.unwrap();
|
||||
let body = String::from_utf8(body.to_vec()).unwrap();
|
||||
assert_eq!(body, "Hello, World!");
|
||||
}
|
||||
}
|
||||
|
|
167
axum/src/serve/listener.rs
Normal file
167
axum/src/serve/listener.rs
Normal file
|
@ -0,0 +1,167 @@
|
|||
use std::{fmt, future::Future, time::Duration};
|
||||
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, TcpStream},
|
||||
};
|
||||
|
||||
/// Types that can listen for connections.
|
||||
pub trait Listener: Send + 'static {
|
||||
/// The listener's IO type.
|
||||
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
|
||||
|
||||
/// The listener's address type.
|
||||
type Addr: Send;
|
||||
|
||||
/// Accept a new incoming connection to this listener.
|
||||
///
|
||||
/// If the underlying accept call can return an error, this function must
|
||||
/// take care of logging and retrying.
|
||||
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
|
||||
|
||||
/// Returns the local address that this listener is bound to.
|
||||
fn local_addr(&self) -> io::Result<Self::Addr>;
|
||||
}
|
||||
|
||||
impl Listener for TcpListener {
|
||||
type Io = TcpStream;
|
||||
type Addr = std::net::SocketAddr;
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
loop {
|
||||
match Self::accept(self).await {
|
||||
Ok(tup) => return tup,
|
||||
Err(e) => handle_accept_error(e).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||
Self::local_addr(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl Listener for tokio::net::UnixListener {
|
||||
type Io = tokio::net::UnixStream;
|
||||
type Addr = tokio::net::unix::SocketAddr;
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
loop {
|
||||
match Self::accept(self).await {
|
||||
Ok(tup) => return tup,
|
||||
Err(e) => handle_accept_error(e).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||
Self::local_addr(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extensions to [`Listener`].
|
||||
pub trait ListenerExt: Listener + Sized {
|
||||
/// Run a mutable closure on every accepted `Io`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use axum::{Router, routing::get, serve::ListenerExt};
|
||||
/// use tracing::trace;
|
||||
///
|
||||
/// # async {
|
||||
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
|
||||
///
|
||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
|
||||
/// .await
|
||||
/// .unwrap()
|
||||
/// .tap_io(|tcp_stream| {
|
||||
/// if let Err(err) = tcp_stream.set_nodelay(true) {
|
||||
/// trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
|
||||
/// }
|
||||
/// });
|
||||
/// axum::serve(listener, router).await.unwrap();
|
||||
/// # };
|
||||
/// ```
|
||||
fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
|
||||
where
|
||||
F: FnMut(&mut Self::Io) + Send + 'static,
|
||||
{
|
||||
TapIo {
|
||||
listener: self,
|
||||
tap_fn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: Listener> ListenerExt for L {}
|
||||
|
||||
/// Return type of [`ListenerExt::tap_io`].
|
||||
///
|
||||
/// See that method for details.
|
||||
pub struct TapIo<L: Listener, F> {
|
||||
listener: L,
|
||||
tap_fn: F,
|
||||
}
|
||||
|
||||
impl<L, F> fmt::Debug for TapIo<L, F>
|
||||
where
|
||||
L: Listener + fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("TapIo")
|
||||
.field("listener", &self.listener)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl<L, F> Listener for TapIo<L, F>
|
||||
where
|
||||
L: Listener,
|
||||
F: FnMut(&mut L::Io) + Send + 'static,
|
||||
{
|
||||
type Io = L::Io;
|
||||
type Addr = L::Addr;
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
let (mut io, addr) = self.listener.accept().await;
|
||||
(self.tap_fn)(&mut io);
|
||||
(io, addr)
|
||||
}
|
||||
|
||||
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_accept_error(e: io::Error) {
|
||||
if is_connection_error(&e) {
|
||||
return;
|
||||
}
|
||||
|
||||
// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
|
||||
//
|
||||
// > A possible scenario is that the process has hit the max open files
|
||||
// > allowed, and so trying to accept a new connection will fail with
|
||||
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
|
||||
// > the application will likely close some files (or connections), and try
|
||||
// > to accept the connection again. If this option is `true`, the error
|
||||
// > will be logged at the `error` level, since it is still a big deal,
|
||||
// > and then the listener will sleep for 1 second.
|
||||
//
|
||||
// hyper allowed customizing this but axum does not.
|
||||
error!("accept error: {e}");
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
|
||||
fn is_connection_error(e: &io::Error) -> bool {
|
||||
matches!(
|
||||
e.kind(),
|
||||
io::ErrorKind::ConnectionRefused
|
||||
| io::ErrorKind::ConnectionAborted
|
||||
| io::ErrorKind::ConnectionReset
|
||||
)
|
||||
}
|
|
@ -21,17 +21,13 @@ mod unix {
|
|||
extract::connect_info::{self, ConnectInfo},
|
||||
http::{Method, Request, StatusCode},
|
||||
routing::get,
|
||||
serve::IncomingStream,
|
||||
Router,
|
||||
};
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Incoming;
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server,
|
||||
};
|
||||
use std::{convert::Infallible, path::PathBuf, sync::Arc};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use tokio::net::{unix::UCred, UnixListener, UnixStream};
|
||||
use tower::Service;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
pub async fn server() {
|
||||
|
@ -52,33 +48,11 @@ mod unix {
|
|||
|
||||
let uds = UnixListener::bind(path.clone()).unwrap();
|
||||
tokio::spawn(async move {
|
||||
let app = Router::new().route("/", get(handler));
|
||||
let app = Router::new()
|
||||
.route("/", get(handler))
|
||||
.into_make_service_with_connect_info::<UdsConnectInfo>();
|
||||
|
||||
let mut make_service = app.into_make_service_with_connect_info::<UdsConnectInfo>();
|
||||
|
||||
// See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for
|
||||
// more details about this setup
|
||||
loop {
|
||||
let (socket, _remote_addr) = uds.accept().await.unwrap();
|
||||
|
||||
let tower_service = unwrap_infallible(make_service.call(&socket).await);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let socket = TokioIo::new(socket);
|
||||
|
||||
let hyper_service =
|
||||
hyper::service::service_fn(move |request: Request<Incoming>| {
|
||||
tower_service.clone().call(request)
|
||||
});
|
||||
|
||||
if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new())
|
||||
.serve_connection_with_upgrades(socket, hyper_service)
|
||||
.await
|
||||
{
|
||||
eprintln!("failed to serve connection: {err:#}");
|
||||
}
|
||||
});
|
||||
}
|
||||
axum::serve(uds, app).await.unwrap();
|
||||
});
|
||||
|
||||
let stream = TokioIo::new(UnixStream::connect(path).await.unwrap());
|
||||
|
@ -117,22 +91,14 @@ mod unix {
|
|||
peer_cred: UCred,
|
||||
}
|
||||
|
||||
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
|
||||
fn connect_info(target: &UnixStream) -> Self {
|
||||
let peer_addr = target.peer_addr().unwrap();
|
||||
let peer_cred = target.peer_cred().unwrap();
|
||||
|
||||
impl connect_info::Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
|
||||
fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self {
|
||||
let peer_addr = stream.io().peer_addr().unwrap();
|
||||
let peer_cred = stream.io().peer_cred().unwrap();
|
||||
Self {
|
||||
peer_addr: Arc::new(peer_addr),
|
||||
peer_cred,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
|
||||
match result {
|
||||
Ok(value) => value,
|
||||
Err(err) => match err {},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue