Remove duplication in serving with and without graceful shutdown (#2803)

This commit is contained in:
Jonas Platte 2024-09-27 21:16:54 +00:00 committed by GitHub
parent 4b48f308c3
commit 5d8541de6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 67 additions and 79 deletions

View file

@ -966,7 +966,7 @@ async fn logging_rejections() {
rejection_type: String,
}
let events = capture_tracing::<RejectionEvent, _, _>(|| async {
let events = capture_tracing::<RejectionEvent, _>(|| async {
let app = Router::new()
.route("/extension", get(|_: Extension<Infallible>| async {}))
.route("/string", post(|_: String| async {}));
@ -987,6 +987,7 @@ async fn logging_rejections() {
StatusCode::BAD_REQUEST,
);
})
.with_filter("axum::rejection=trace")
.await;
assert_eq!(

View file

@ -213,61 +213,8 @@ where
type IntoFuture = private::ServeFuture;
fn into_future(self) -> Self::IntoFuture {
private::ServeFuture(Box::pin(async move {
let Self {
tcp_listener,
mut make_service,
tcp_nodelay,
_marker: _,
} = self;
loop {
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
Some(conn) => conn,
None => continue,
};
if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}
let tcp_stream = TokioIo::new(tcp_stream);
poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});
let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));
let hyper_service = TowerToHyperService::new(tower_service);
tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
.await
{
Ok(()) => {}
Err(_err) => {
// This error only appears when the client doesn't send a request and
// terminate the connection.
//
// If client sends one request then terminate connection whenever, it doesn't
// appear.
}
}
});
}
}))
self.with_graceful_shutdown(std::future::pending())
.into_future()
}
}

View file

@ -1,7 +1,14 @@
use crate::util::AxumMutex;
use std::{future::Future, io, sync::Arc};
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
};
use serde::{de::DeserializeOwned, Deserialize};
use tracing::instrument::WithSubscriber;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};
@ -14,36 +21,69 @@ pub(crate) struct TracingEvent<T> {
}
/// Run an async closure and capture the tracing output it produces.
pub(crate) async fn capture_tracing<T, F, Fut>(f: F) -> Vec<TracingEvent<T>>
pub(crate) fn capture_tracing<T, F>(f: F) -> CaptureTracing<T, F>
where
F: Fn() -> Fut,
Fut: Future,
T: DeserializeOwned,
{
let (make_writer, handle) = TestMakeWriter::new();
CaptureTracing {
f,
filter: None,
_phantom: PhantomData,
}
}
let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter("axum=trace".parse::<Targets>().unwrap()),
);
pub(crate) struct CaptureTracing<T, F> {
f: F,
filter: Option<Targets>,
_phantom: PhantomData<fn() -> T>,
}
let guard = tracing::subscriber::set_default(subscriber);
impl<T, F> CaptureTracing<T, F> {
pub(crate) fn with_filter(mut self, filter_string: &str) -> Self {
self.filter = Some(filter_string.parse().unwrap());
self
}
}
f().await;
impl<T, F, Fut> IntoFuture for CaptureTracing<T, F>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future + Send,
T: DeserializeOwned,
{
type Output = Vec<TracingEvent<T>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
drop(guard);
fn into_future(self) -> Self::IntoFuture {
let Self { f, filter, .. } = self;
Box::pin(async move {
let (make_writer, handle) = TestMakeWriter::new();
handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap());
let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter(filter),
);
let guard = tracing::subscriber::set_default(subscriber);
f().with_current_subscriber().await;
drop(guard);
handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
})
}
}
struct TestMakeWriter {