diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index ee116a41..4ff55ae4 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -328,3 +328,21 @@ async fn merge_router_with_fallback_into_empty() { assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } + +#[crate::test] +async fn state_isnt_cloned_too_much_with_fallback() { + let state = CountingCloneableState::new(); + + let app = Router::new() + .fallback(|_: State| async {}) + .with_state(state.clone()); + + let client = TestClient::new(app); + + // ignore clones made during setup + state.setup_done(); + + client.get("/does-not-exist").await; + + assert_eq!(state.count(), 4); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 3dad0728..7e91a977 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -16,6 +16,7 @@ use crate::{ BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; +use counting_cloneable_state::CountingCloneableState; use futures_util::stream::StreamExt; use http::{ header::{ALLOW, CONTENT_LENGTH, HOST}, @@ -27,7 +28,7 @@ use serde_json::json; use std::{ convert::Infallible, future::{ready, IntoFuture, Ready}, - sync::atomic::{AtomicBool, AtomicUsize, Ordering}, + sync::atomic::{AtomicUsize, Ordering}, task::{Context, Poll}, time::Duration, }; @@ -905,54 +906,20 @@ fn test_path_for_nested_route() { #[crate::test] async fn state_isnt_cloned_too_much() { - static SETUP_DONE: AtomicBool = AtomicBool::new(false); - static COUNT: AtomicUsize = AtomicUsize::new(0); - - struct AppState; - - impl Clone for AppState { - fn clone(&self) -> Self { - #[rustversion::since(1.66)] - #[track_caller] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - let bt = std::backtrace::Backtrace::force_capture(); - let bt = bt - .to_string() - .lines() - .filter(|line| line.contains("axum") || line.contains("./src")) - .collect::>() - .join("\n"); - println!("AppState::Clone:\n===============\n{bt}\n"); - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - #[rustversion::not(since(1.66))] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - count(); - - Self - } - } + let state = CountingCloneableState::new(); let app = Router::new() - .route("/", get(|_: State| async {})) - .with_state(AppState); + .route("/", get(|_: State| async {})) + .with_state(state.clone()); let client = TestClient::new(app); // ignore clones made during setup - SETUP_DONE.store(true, Ordering::SeqCst); + state.setup_done(); client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 3); + assert_eq!(state.count(), 3); } #[crate::test] diff --git a/axum/src/test_helpers/counting_cloneable_state.rs b/axum/src/test_helpers/counting_cloneable_state.rs new file mode 100644 index 00000000..762d5ce9 --- /dev/null +++ b/axum/src/test_helpers/counting_cloneable_state.rs @@ -0,0 +1,52 @@ +use std::sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, +}; + +pub(crate) struct CountingCloneableState { + state: Arc, +} + +struct InnerState { + setup_done: AtomicBool, + count: AtomicUsize, +} + +impl CountingCloneableState { + pub(crate) fn new() -> Self { + let inner_state = InnerState { + setup_done: AtomicBool::new(false), + count: AtomicUsize::new(0), + }; + CountingCloneableState { + state: Arc::new(inner_state), + } + } + + pub(crate) fn setup_done(&self) { + self.state.setup_done.store(true, Ordering::SeqCst); + } + + pub(crate) fn count(&self) -> usize { + self.state.count.load(Ordering::SeqCst) + } +} + +impl Clone for CountingCloneableState { + fn clone(&self) -> Self { + let state = self.state.clone(); + if state.setup_done.load(Ordering::SeqCst) { + let bt = std::backtrace::Backtrace::force_capture(); + let bt = bt + .to_string() + .lines() + .filter(|line| line.contains("axum") || line.contains("./src")) + .collect::>() + .join("\n"); + println!("AppState::Clone:\n===============\n{bt}\n"); + state.count.fetch_add(1, Ordering::SeqCst); + } + + CountingCloneableState { state } + } +} diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index c6ae1bff..5c29f78d 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -7,6 +7,8 @@ pub(crate) use self::test_client::*; pub(crate) mod tracing_helpers; +pub(crate) mod counting_cloneable_state; + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {}