From fe9c4a0b5ba9c86c8b7913c1390962c73d32621c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 12 Mar 2023 15:43:22 +0100 Subject: [PATCH] Clone state a bit less (#1837) --- axum/CHANGELOG.md | 2 ++ axum/Cargo.toml | 1 + axum/src/routing/mod.rs | 55 +++++++++++++++-------------------- axum/src/routing/tests/mod.rs | 54 +++++++++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 33 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 7bb85253..64c6d7dc 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **fixed:** Don't require `S: Debug` for `impl Debug for Router` ([#1836]) +- **fixed:** Clone state a bit less when handling requests ([#1837]) [#1836]: https://github.com/tokio-rs/axum/pull/1836 +[#1837]: https://github.com/tokio-rs/axum/pull/1837 # 0.6.10 (03. March, 2023) diff --git a/axum/Cargo.toml b/axum/Cargo.toml index bf05310e..e9460852 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -107,6 +107,7 @@ futures = "0.3" quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] } +rustversion = "1.0.9" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" time = { version = "0.3", features = ["serde-human-readable"] } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 4f028389..7322401c 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -447,7 +447,29 @@ where } } - self.call_route(match_, req, state) + let id = *match_.value; + + #[cfg(feature = "matched-path")] + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + req.extensions_mut(), + ); + + url_params::insert_url_params(req.extensions_mut(), match_.params); + + let endpont = self + .routes + .get_mut(&id) + .expect("no route for id. This is a bug in axum. Please file an issue"); + + match endpont { + Endpoint::MethodRouter(method_router) => { + method_router.call_with_state(req, state) + } + Endpoint::Route(route) => route.call(req), + Endpoint::NestedRouter(router) => router.clone().call_with_state(req, state), + } } Err( MatchError::NotFound @@ -469,37 +491,6 @@ where } } - #[inline] - fn call_route( - &self, - match_: matchit::Match<&RouteId>, - mut req: Request, - state: S, - ) -> RouteFuture { - let id = *match_.value; - - #[cfg(feature = "matched-path")] - crate::extract::matched_path::set_matched_path_for_request( - id, - &self.node.route_id_to_path, - req.extensions_mut(), - ); - - url_params::insert_url_params(req.extensions_mut(), match_.params); - - let endpont = self - .routes - .get(&id) - .expect("no route for id. This is a bug in axum. Please file an issue") - .clone(); - - match endpont { - Endpoint::MethodRouter(mut method_router) => method_router.call_with_state(req, state), - Endpoint::Route(mut route) => route.call(req), - Endpoint::NestedRouter(router) => router.call_with_state(req, state), - } - } - fn next_route_id(&mut self) -> RouteId { let next_id = self .prev_route_id diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 33770e33..70fccaf0 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -15,7 +15,7 @@ use serde_json::json; use std::{ convert::Infallible, future::{ready, Ready}, - sync::atomic::{AtomicUsize, Ordering}, + sync::atomic::{AtomicBool, AtomicUsize, Ordering}, task::{Context, Poll}, time::Duration, }; @@ -812,3 +812,55 @@ fn method_router_fallback_with_state() { .fallback(get(fallback).fallback(not_found)) .with_state(state); } + +#[crate::test] +async fn state_isnt_cloned_too_much() { + static SETUP_DONE: AtomicBool = AtomicBool::new(false); + static COUNT: AtomicUsize = AtomicUsize::new(0); + + struct AppState; + + impl Clone for AppState { + fn clone(&self) -> Self { + #[rustversion::since(1.65)] + #[track_caller] + fn count() { + if SETUP_DONE.load(Ordering::SeqCst) { + let bt = std::backtrace::Backtrace::force_capture(); + let bt = bt + .to_string() + .lines() + .filter(|line| line.contains("axum") || line.contains("./src")) + .collect::>() + .join("\n"); + println!("AppState::Clone:\n===============\n{}\n", bt); + COUNT.fetch_add(1, Ordering::SeqCst); + } + } + + #[rustversion::not(since(1.65))] + fn count() { + if SETUP_DONE.load(Ordering::SeqCst) { + COUNT.fetch_add(1, Ordering::SeqCst); + } + } + + count(); + + Self + } + } + + let app = Router::new() + .route("/", get(|_: State| async {})) + .with_state(AppState); + + let client = TestClient::new(app); + + // ignore clones made during setup + SETUP_DONE.store(true, Ordering::SeqCst); + + client.get("/").send().await; + + assert_eq!(COUNT.load(Ordering::SeqCst), 4); +}