From 5d8541de6e5fc0970ea7e29ee3605cb730ac0e71 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 27 Sep 2024 21:16:54 +0000 Subject: [PATCH] Remove duplication in serving with and without graceful shutdown (#2803) --- axum/src/routing/tests/mod.rs | 3 +- axum/src/serve.rs | 57 +--------------- axum/src/test_helpers/tracing_helpers.rs | 86 +++++++++++++++++------- 3 files changed, 67 insertions(+), 79 deletions(-) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 144c870d..e9ab56e1 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -966,7 +966,7 @@ async fn logging_rejections() { rejection_type: String, } - let events = capture_tracing::(|| async { + let events = capture_tracing::(|| async { let app = Router::new() .route("/extension", get(|_: Extension| async {})) .route("/string", post(|_: String| async {})); @@ -987,6 +987,7 @@ async fn logging_rejections() { StatusCode::BAD_REQUEST, ); }) + .with_filter("axum::rejection=trace") .await; assert_eq!( diff --git a/axum/src/serve.rs b/axum/src/serve.rs index e2b974cf..1ba9a145 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -213,61 +213,8 @@ where type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { - private::ServeFuture(Box::pin(async move { - let Self { - tcp_listener, - mut make_service, - tcp_nodelay, - _marker: _, - } = self; - - loop { - let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await { - Some(conn) => conn, - None => continue, - }; - - if let Some(nodelay) = tcp_nodelay { - if let Err(err) = tcp_stream.set_nodelay(nodelay) { - trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); - } - } - - let tcp_stream = TokioIo::new(tcp_stream); - - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .unwrap_or_else(|err| match err {}); - - let tower_service = make_service - .call(IncomingStream { - tcp_stream: &tcp_stream, - remote_addr, - }) - .await - .unwrap_or_else(|err| match err {}) - .map_request(|req: Request| req.map(Body::new)); - - let hyper_service = TowerToHyperService::new(tower_service); - - tokio::spawn(async move { - match Builder::new(TokioExecutor::new()) - // upgrades needed for websockets - .serve_connection_with_upgrades(tcp_stream, hyper_service) - .await - { - Ok(()) => {} - Err(_err) => { - // This error only appears when the client doesn't send a request and - // terminate the connection. - // - // If client sends one request then terminate connection whenever, it doesn't - // appear. - } - } - }); - } - })) + self.with_graceful_shutdown(std::future::pending()) + .into_future() } } diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index 2240717e..96abe510 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -1,7 +1,14 @@ use crate::util::AxumMutex; -use std::{future::Future, io, sync::Arc}; +use std::{ + future::{Future, IntoFuture}, + io, + marker::PhantomData, + pin::Pin, + sync::Arc, +}; use serde::{de::DeserializeOwned, Deserialize}; +use tracing::instrument::WithSubscriber; use tracing_subscriber::prelude::*; use tracing_subscriber::{filter::Targets, fmt::MakeWriter}; @@ -14,36 +21,69 @@ pub(crate) struct TracingEvent { } /// Run an async closure and capture the tracing output it produces. -pub(crate) async fn capture_tracing(f: F) -> Vec> +pub(crate) fn capture_tracing(f: F) -> CaptureTracing where - F: Fn() -> Fut, - Fut: Future, T: DeserializeOwned, { - let (make_writer, handle) = TestMakeWriter::new(); + CaptureTracing { + f, + filter: None, + _phantom: PhantomData, + } +} - let subscriber = tracing_subscriber::registry().with( - tracing_subscriber::fmt::layer() - .with_writer(make_writer) - .with_target(true) - .without_time() - .with_ansi(false) - .json() - .flatten_event(false) - .with_filter("axum=trace".parse::().unwrap()), - ); +pub(crate) struct CaptureTracing { + f: F, + filter: Option, + _phantom: PhantomData T>, +} - let guard = tracing::subscriber::set_default(subscriber); +impl CaptureTracing { + pub(crate) fn with_filter(mut self, filter_string: &str) -> Self { + self.filter = Some(filter_string.parse().unwrap()); + self + } +} - f().await; +impl IntoFuture for CaptureTracing +where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send, + T: DeserializeOwned, +{ + type Output = Vec>; + type IntoFuture = Pin + Send>>; - drop(guard); + fn into_future(self) -> Self::IntoFuture { + let Self { f, filter, .. } = self; + Box::pin(async move { + let (make_writer, handle) = TestMakeWriter::new(); - handle - .take() - .lines() - .map(|line| serde_json::from_str(line).unwrap()) - .collect() + let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap()); + let subscriber = tracing_subscriber::registry().with( + tracing_subscriber::fmt::layer() + .with_writer(make_writer) + .with_target(true) + .without_time() + .with_ansi(false) + .json() + .flatten_event(false) + .with_filter(filter), + ); + + let guard = tracing::subscriber::set_default(subscriber); + + f().with_current_subscriber().await; + + drop(guard); + + handle + .take() + .lines() + .map(|line| serde_json::from_str(line).unwrap()) + .collect() + }) + } } struct TestMakeWriter {