From 28d8d9b747efa9adc24aafb404b45790a1e85269 Mon Sep 17 00:00:00 2001 From: Tobias Bieniek Date: Fri, 27 Dec 2024 18:24:50 +0100 Subject: [PATCH] Refactor `TestClient` usage (#3121) --- Cargo.lock | 1 + axum-core/Cargo.toml | 3 +- axum-core/src/extract/request_parts.rs | 25 +++++++++++++++++ axum-core/src/lib.rs | 3 ++ axum-extra/Cargo.toml | 2 +- axum-extra/src/lib.rs | 14 +--------- axum/Cargo.toml | 8 ++++++ axum/src/extract/mod.rs | 1 - axum/src/extract/request_parts.rs | 27 ------------------ axum/src/lib.rs | 5 ++-- axum/src/test_helpers/mod.rs | 6 +++- axum/src/test_helpers/test_client.rs | 38 +++++++++++++------------- 12 files changed, 68 insertions(+), 65 deletions(-) delete mode 100644 axum/src/extract/request_parts.rs diff --git a/Cargo.lock b/Cargo.lock index 71471050..876912df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -376,6 +376,7 @@ version = "0.5.0-rc.1" dependencies = [ "axum", "axum-extra", + "axum-macros", "bytes", "futures-util", "http 1.2.0", diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index f17c8d9b..d67fb18c 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -35,8 +35,9 @@ tower-http = { version = "0.6.0", optional = true, features = ["limit"] } tracing = { version = "0.1.37", default-features = false, optional = true } [dev-dependencies] -axum = { path = "../axum" } +axum = { path = "../axum", features = ["__private"] } axum-extra = { path = "../axum-extra", features = ["typed-header"] } +axum-macros = { path = "../axum-macros", features = ["__private"] } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = "1.0.0" tokio = { version = "1.25.0", features = ["macros"] } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 70fc021a..ffa358c6 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -166,3 +166,28 @@ where Ok(req.into_body()) } } + +#[cfg(test)] +mod tests { + use axum::{extract::Extension, routing::get, test_helpers::*, Router}; + use http::{Method, StatusCode}; + + #[crate::test] + async fn extract_request_parts() { + #[derive(Clone)] + struct Ext; + + async fn handler(parts: http::request::Parts) { + assert_eq!(parts.method, Method::GET); + assert_eq!(parts.uri, "/"); + assert_eq!(parts.version, http::Version::HTTP_11); + assert_eq!(parts.headers["x-foo"], "123"); + parts.extensions.get::().unwrap(); + } + + let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext))); + + let res = client.get("/").header("x-foo", "123").await; + assert_eq!(res.status(), StatusCode::OK); + } +} diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index 60f0dcfc..1fd45347 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -31,3 +31,6 @@ pub mod response; pub type BoxError = Box; pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; + +#[cfg(test)] +use axum_macros::__private_axum_test as test; diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 16c6d38b..e52a109d 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -76,7 +76,7 @@ tracing = { version = "0.1.37", default-features = false, optional = true } typed-json = { version = "0.1.1", optional = true } [dev-dependencies] -axum = { path = "../axum", features = ["macros"] } +axum = { path = "../axum", features = ["macros", "__private"] } axum-macros = { path = "../axum-macros", features = ["__private"] } hyper = "1.0.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index e672d7ad..67d50137 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -80,16 +80,4 @@ pub mod __private { use axum_macros::__private_axum_test as test; #[cfg(test)] -#[allow(unused_imports)] -pub(crate) mod test_helpers { - use axum::{extract::Request, response::Response, serve}; - - mod test_client { - #![allow(dead_code)] - include!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/../axum/src/test_helpers/test_client.rs" - )); - } - pub(crate) use self::test_client::*; -} +pub(crate) use axum::test_helpers; diff --git a/axum/Cargo.toml b/axum/Cargo.toml index 78c6ccf8..8110f052 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -45,6 +45,10 @@ __private_docs = [ "tower/full", "dep:tower-http", ] +# This feature is used to enable private test helper usage +# in `axum-core` and `axum-extra`. +__private = ["tokio", "http1", "dep:reqwest"] + [dependencies] axum-core = { path = "../axum-core", version = "0.5.0-rc.1" } bytes = "1.0" @@ -72,6 +76,7 @@ form_urlencoded = { version = "1.1.0", optional = true } hyper = { version = "1.1.0", optional = true } hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true } multer = { version = "3.0.0", optional = true } +reqwest = { version = "0.12", optional = true, default-features = false, features = ["json", "stream", "multipart"] } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } serde_urlencoded = { version = "0.7", optional = true } @@ -214,6 +219,9 @@ allowed = [ "http_body", "serde", "tokio", + + # for the `__private` feature + "reqwest", ] [[bench]] diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index f63a7506..8a0af3da 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -15,7 +15,6 @@ pub(crate) mod nested_path; mod original_uri; mod raw_form; mod raw_query; -mod request_parts; mod state; #[doc(inline)] diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs deleted file mode 100644 index 0fb3247b..00000000 --- a/axum/src/extract/request_parts.rs +++ /dev/null @@ -1,27 +0,0 @@ -/// This module contains the tests for the `impl FromRequestParts for Parts` -/// implementation in the `axum-core` crate. The tests cannot be moved there -/// because we don't have access to the `TestClient` and `Router` types there. -#[cfg(test)] -mod tests { - use crate::{extract::Extension, routing::get, test_helpers::*, Router}; - use http::{Method, StatusCode}; - - #[crate::test] - async fn extract_request_parts() { - #[derive(Clone)] - struct Ext; - - async fn handler(parts: http::request::Parts) { - assert_eq!(parts.method, Method::GET); - assert_eq!(parts.uri, "/"); - assert_eq!(parts.version, http::Version::HTTP_11); - assert_eq!(parts.headers["x-foo"], "123"); - parts.extensions.get::().unwrap(); - } - - let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext))); - - let res = client.get("/").header("x-foo", "123").await; - assert_eq!(res.status(), StatusCode::OK); - } -} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 09e0edf0..217b4cc9 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -446,8 +446,9 @@ pub mod routing; #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] pub mod serve; -#[cfg(test)] -mod test_helpers; +#[cfg(any(test, feature = "__private"))] +#[allow(missing_docs, missing_debug_implementations, clippy::print_stdout)] +pub mod test_helpers; #[doc(no_inline)] pub use http; diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index 5c29f78d..06f9101c 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -3,13 +3,17 @@ use crate::{extract::Request, response::Response, serve}; mod test_client; -pub(crate) use self::test_client::*; +pub use self::test_client::*; +#[cfg(test)] pub(crate) mod tracing_helpers; +#[cfg(test)] pub(crate) mod counting_cloneable_state; +#[cfg(test)] pub(crate) fn assert_send() {} +#[cfg(test)] pub(crate) fn assert_sync() {} #[allow(dead_code)] diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 7a177aa5..3981db5a 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -29,13 +29,13 @@ where addr } -pub(crate) struct TestClient { +pub struct TestClient { client: reqwest::Client, addr: SocketAddr, } impl TestClient { - pub(crate) fn new(svc: S) -> Self + pub fn new(svc: S) -> Self where S: Service + Clone + Send + 'static, S::Future: Send, @@ -50,55 +50,55 @@ impl TestClient { TestClient { client, addr } } - pub(crate) fn get(&self, url: &str) -> RequestBuilder { + pub fn get(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.get(format!("http://{}{url}", self.addr)), } } - pub(crate) fn head(&self, url: &str) -> RequestBuilder { + pub fn head(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.head(format!("http://{}{url}", self.addr)), } } - pub(crate) fn post(&self, url: &str) -> RequestBuilder { + pub fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.post(format!("http://{}{url}", self.addr)), } } #[allow(dead_code)] - pub(crate) fn put(&self, url: &str) -> RequestBuilder { + pub fn put(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.put(format!("http://{}{url}", self.addr)), } } #[allow(dead_code)] - pub(crate) fn patch(&self, url: &str) -> RequestBuilder { + pub fn patch(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.patch(format!("http://{}{url}", self.addr)), } } #[allow(dead_code)] - pub(crate) fn server_port(&self) -> u16 { + pub fn server_port(&self) -> u16 { self.addr.port() } } -pub(crate) struct RequestBuilder { +pub struct RequestBuilder { builder: reqwest::RequestBuilder, } impl RequestBuilder { - pub(crate) fn body(mut self, body: impl Into) -> Self { + pub fn body(mut self, body: impl Into) -> Self { self.builder = self.builder.body(body); self } - pub(crate) fn json(mut self, json: &T) -> Self + pub fn json(mut self, json: &T) -> Self where T: serde::Serialize, { @@ -106,7 +106,7 @@ impl RequestBuilder { self } - pub(crate) fn header(mut self, key: K, value: V) -> Self + pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, @@ -118,7 +118,7 @@ impl RequestBuilder { } #[allow(dead_code)] - pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self { + pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self { self.builder = self.builder.multipart(form); self } @@ -138,7 +138,7 @@ impl IntoFuture for RequestBuilder { } #[derive(Debug)] -pub(crate) struct TestResponse { +pub struct TestResponse { response: reqwest::Response, } @@ -152,27 +152,27 @@ impl Deref for TestResponse { impl TestResponse { #[allow(dead_code)] - pub(crate) async fn bytes(self) -> Bytes { + pub async fn bytes(self) -> Bytes { self.response.bytes().await.unwrap() } - pub(crate) async fn text(self) -> String { + pub async fn text(self) -> String { self.response.text().await.unwrap() } #[allow(dead_code)] - pub(crate) async fn json(self) -> T + pub async fn json(self) -> T where T: serde::de::DeserializeOwned, { self.response.json().await.unwrap() } - pub(crate) async fn chunk(&mut self) -> Option { + pub async fn chunk(&mut self) -> Option { self.response.chunk().await.unwrap() } - pub(crate) async fn chunk_text(&mut self) -> Option { + pub async fn chunk_text(&mut self) -> Option { let chunk = self.chunk().await?; Some(String::from_utf8(chunk.to_vec()).unwrap()) }