Clone state a bit less (#1837)

This commit is contained in:
David Pedersen 2023-03-12 15:43:22 +01:00 committed by GitHub
parent f65fa22f34
commit fe9c4a0b5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 33 deletions

View file

@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- **fixed:** Don't require `S: Debug` for `impl Debug for Router<S>` ([#1836]) - **fixed:** Don't require `S: Debug` for `impl Debug for Router<S>` ([#1836])
- **fixed:** Clone state a bit less when handling requests ([#1837])
[#1836]: https://github.com/tokio-rs/axum/pull/1836 [#1836]: https://github.com/tokio-rs/axum/pull/1836
[#1837]: https://github.com/tokio-rs/axum/pull/1837
# 0.6.10 (03. March, 2023) # 0.6.10 (03. March, 2023)

View file

@ -107,6 +107,7 @@ futures = "0.3"
quickcheck = "1.0" quickcheck = "1.0"
quickcheck_macros = "1.0" quickcheck_macros = "1.0"
reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] } reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] }
rustversion = "1.0.9"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
time = { version = "0.3", features = ["serde-human-readable"] } time = { version = "0.3", features = ["serde-human-readable"] }

View file

@ -447,7 +447,29 @@ where
} }
} }
self.call_route(match_, req, state) let id = *match_.value;
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let endpont = self
.routes
.get_mut(&id)
.expect("no route for id. This is a bug in axum. Please file an issue");
match endpont {
Endpoint::MethodRouter(method_router) => {
method_router.call_with_state(req, state)
}
Endpoint::Route(route) => route.call(req),
Endpoint::NestedRouter(router) => router.clone().call_with_state(req, state),
}
} }
Err( Err(
MatchError::NotFound MatchError::NotFound
@ -469,37 +491,6 @@ where
} }
} }
#[inline]
fn call_route(
&self,
match_: matchit::Match<&RouteId>,
mut req: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
let id = *match_.value;
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let endpont = self
.routes
.get(&id)
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
match endpont {
Endpoint::MethodRouter(mut method_router) => method_router.call_with_state(req, state),
Endpoint::Route(mut route) => route.call(req),
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
}
}
fn next_route_id(&mut self) -> RouteId { fn next_route_id(&mut self) -> RouteId {
let next_id = self let next_id = self
.prev_route_id .prev_route_id

View file

@ -15,7 +15,7 @@ use serde_json::json;
use std::{ use std::{
convert::Infallible, convert::Infallible,
future::{ready, Ready}, future::{ready, Ready},
sync::atomic::{AtomicUsize, Ordering}, sync::atomic::{AtomicBool, AtomicUsize, Ordering},
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
@ -812,3 +812,55 @@ fn method_router_fallback_with_state() {
.fallback(get(fallback).fallback(not_found)) .fallback(get(fallback).fallback(not_found))
.with_state(state); .with_state(state);
} }
#[crate::test]
async fn state_isnt_cloned_too_much() {
static SETUP_DONE: AtomicBool = AtomicBool::new(false);
static COUNT: AtomicUsize = AtomicUsize::new(0);
struct AppState;
impl Clone for AppState {
fn clone(&self) -> Self {
#[rustversion::since(1.65)]
#[track_caller]
fn count() {
if SETUP_DONE.load(Ordering::SeqCst) {
let bt = std::backtrace::Backtrace::force_capture();
let bt = bt
.to_string()
.lines()
.filter(|line| line.contains("axum") || line.contains("./src"))
.collect::<Vec<_>>()
.join("\n");
println!("AppState::Clone:\n===============\n{}\n", bt);
COUNT.fetch_add(1, Ordering::SeqCst);
}
}
#[rustversion::not(since(1.65))]
fn count() {
if SETUP_DONE.load(Ordering::SeqCst) {
COUNT.fetch_add(1, Ordering::SeqCst);
}
}
count();
Self
}
}
let app = Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(AppState);
let client = TestClient::new(app);
// ignore clones made during setup
SETUP_DONE.store(true, Ordering::SeqCst);
client.get("/").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
}