From 64960bb19cebf87b25f4faca6d5fda2fc6ef49a3 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 18 Nov 2022 12:02:58 +0100 Subject: [PATCH] Type safe state inheritance (#1532) * Make state type safe * fix examples * remove unnecessary `#[track_caller]`s * Router::into_service -> Router::with_state * fixup docs * macro docs * add missing docs * fix examples * format * changelog * Update trybuild tests * Make sure fallbacks are still inherited for opaque services (#1540) * Document nesting routers with different state * fix leftover conflicts --- axum-core/src/extract/default_body_limit.rs | 10 +- axum-extra/src/extract/cookie/mod.rs | 8 +- axum-extra/src/extract/cookie/private.rs | 7 +- axum-extra/src/extract/cookie/signed.rs | 7 +- axum-extra/src/routing/resource.rs | 18 +- axum-extra/src/routing/spa.rs | 24 +- axum-macros/Cargo.toml | 2 +- axum-macros/src/lib.rs | 26 +- axum-macros/tests/from_ref/pass/basic.rs | 7 +- .../fail/parts_extracting_body.rs | 1 - .../fail/parts_extracting_body.stderr | 4 +- .../tests/from_request/pass/container.rs | 1 - .../from_request/pass/container_parts.rs | 1 - axum-macros/tests/from_request/pass/named.rs | 1 - .../tests/from_request/pass/named_parts.rs | 1 - .../tests/from_request/pass/named_via.rs | 1 - .../from_request/pass/named_via_parts.rs | 1 - .../from_request/pass/override_rejection.rs | 1 - .../pass/override_rejection_parts.rs | 1 - .../tests/from_request/pass/state_enum_via.rs | 5 +- .../from_request/pass/state_enum_via_parts.rs | 5 +- .../tests/from_request/pass/state_explicit.rs | 5 +- .../from_request/pass/state_explicit_parts.rs | 5 +- .../from_request/pass/state_field_explicit.rs | 5 +- .../from_request/pass/state_field_infer.rs | 5 +- .../tests/from_request/pass/state_via.rs | 5 +- .../from_request/pass/state_via_infer.rs | 5 +- .../from_request/pass/state_via_parts.rs | 5 +- .../from_request/pass/state_with_rejection.rs | 5 +- axum/CHANGELOG.md | 22 ++ axum/benches/benches.rs | 53 ++- axum/src/boxed.rs | 173 +++++++++ axum/src/docs/middleware.md | 7 +- .../into_make_service_with_connect_info.md | 4 + axum/src/docs/routing/nest.md | 37 ++ axum/src/extract/state.rs | 16 +- axum/src/handler/boxed.rs | 123 ------ axum/src/handler/mod.rs | 2 - axum/src/lib.rs | 6 +- axum/src/middleware/from_fn.rs | 7 +- axum/src/middleware/map_request.rs | 7 +- axum/src/middleware/map_response.rs | 7 +- axum/src/response/mod.rs | 4 +- axum/src/routing/method_routing.rs | 85 +---- axum/src/routing/mod.rs | 356 +++++++----------- axum/src/routing/route.rs | 24 +- axum/src/routing/service.rs | 41 +- axum/src/routing/tests/fallback.rs | 51 ++- axum/src/routing/tests/merge.rs | 84 ----- axum/src/routing/tests/mod.rs | 16 +- axum/src/routing/tests/nest.rs | 47 +-- axum/src/test_helpers/test_client.rs | 5 +- examples/async-graphql/src/main.rs | 4 +- examples/chat/src/main.rs | 5 +- .../src/main.rs | 5 +- examples/key-value-store/src/main.rs | 14 +- examples/oauth/src/main.rs | 5 +- examples/reverse-proxy/src/main.rs | 2 +- examples/sessions/src/main.rs | 2 +- examples/sqlx-postgres/src/main.rs | 10 +- examples/todos/src/main.rs | 5 +- examples/tokio-postgres/src/main.rs | 10 +- 62 files changed, 675 insertions(+), 736 deletions(-) create mode 100644 axum/src/boxed.rs delete mode 100644 axum/src/handler/boxed.rs diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs index 9b9aa825..14c155ab 100644 --- a/axum-core/src/extract/default_body_limit.rs +++ b/axum-core/src/extract/default_body_limit.rs @@ -34,13 +34,14 @@ use tower_layer::Layer; /// http::Request, /// }; /// -/// let router = Router::new() +/// let app = Router::new() /// .route( /// "/", /// // even with `DefaultBodyLimit` the request body is still just `Body` /// post(|request: Request| async {}), /// ) /// .layer(DefaultBodyLimit::max(1024)); +/// # let _: Router<(), _> = app; /// ``` /// /// ``` @@ -48,7 +49,7 @@ use tower_layer::Layer; /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// -/// let router = Router::new() +/// let app = Router::new() /// .route( /// "/", /// // `RequestBodyLimitLayer` changes the request body type to `Limited` @@ -56,6 +57,7 @@ use tower_layer::Layer; /// post(|request: Request>| async {}), /// ) /// .layer(RequestBodyLimitLayer::new(1024)); +/// # let _: Router<(), _> = app; /// ``` /// /// In general using `DefaultBodyLimit` is recommended but if you need to use third party @@ -102,7 +104,7 @@ impl DefaultBodyLimit { /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// - /// let app: Router<_, Limited> = Router::new() + /// let app: Router<(), Limited> = Router::new() /// .route("/", get(|body: Bytes| async {})) /// // Disable the default limit /// .layer(DefaultBodyLimit::disable()) @@ -138,7 +140,7 @@ impl DefaultBodyLimit { /// use tower_http::limit::RequestBodyLimitLayer; /// use http_body::Limited; /// - /// let app: Router<_, Limited> = Router::new() + /// let app: Router<(), Limited> = Router::new() /// .route("/", get(|body: Bytes| async {})) /// // Replace the default of 2MB with 1024 bytes. /// .layer(DefaultBodyLimit::max(1024)); diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index cbbbdd1b..0c6bcc4c 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -255,11 +255,11 @@ mod tests { custom_key: CustomKey(Key::generate()), }; - let app = Router::<_, Body>::with_state(state) + let app = Router::<_, Body>::new() .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) .route("/remove", get(remove_cookie)) - .into_service(); + .with_state(state); let res = app .clone() @@ -352,9 +352,9 @@ mod tests { custom_key: CustomKey(Key::generate()), }; - let app = Router::<_, Body>::with_state(state) + let app = Router::<_, Body>::new() .route("/get", get(get_cookie)) - .into_service(); + .with_state(state); let res = app .clone() diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 0b08fdc9..cceea95e 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -64,10 +64,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// key: Key::generate(), /// }; /// -/// let app = Router::with_state(state) +/// let app = Router::new() /// .route("/set", post(set_secret)) -/// .route("/get", get(get_secret)); -/// # let app: Router<_> = app; +/// .route("/get", get(get_secret)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` pub struct PrivateCookieJar { jar: cookie::CookieJar, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index ca0aa4ca..1ef837a2 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -82,10 +82,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// key: Key::generate(), /// }; /// -/// let app = Router::with_state(state) +/// let app = Router::new() /// .route("/sessions", post(create_session)) -/// .route("/me", get(me)); -/// # let app: Router<_> = app; +/// .route("/me", get(me)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` pub struct SignedCookieJar { jar: cookie::CookieJar, diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 7648d9f5..25d56643 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -38,30 +38,18 @@ pub struct Resource { pub(crate) router: Router, } -impl Resource<(), B> -where - B: axum::body::HttpBody + Send + 'static, -{ - /// Create a `Resource` with the given name. - /// - /// All routes will be nested at `/{resource_name}`. - pub fn named(resource_name: &str) -> Self { - Self::named_with((), resource_name) - } -} - impl Resource where B: axum::body::HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { - /// Create a `Resource` with the given name and state. + /// Create a `Resource` with the given name. /// /// All routes will be nested at `/{resource_name}`. - pub fn named_with(state: S, resource_name: &str) -> Self { + pub fn named(resource_name: &str) -> Self { Self { name: resource_name.to_owned(), - router: Router::with_state(state), + router: Router::new(), } } diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 2c402fb6..5fc6d852 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -50,10 +50,10 @@ use tower_service::Service; /// - `GET /some/other/path` will serve `index.html` since there isn't another /// route for it /// - `GET /api/foo` will serve the `api_foo` handler function -pub struct SpaRouter Ready> { +pub struct SpaRouter Ready> { paths: Arc, handle_error: F, - _marker: PhantomData (B, T)>, + _marker: PhantomData (S, B, T)>, } #[derive(Debug)] @@ -63,7 +63,7 @@ struct Paths { index_file: PathBuf, } -impl SpaRouter Ready> { +impl SpaRouter Ready> { /// Create a new `SpaRouter`. /// /// Assets will be served at `GET /{serve_assets_at}` from the directory at `assets_dir`. @@ -86,7 +86,7 @@ impl SpaRouter Ready> { } } -impl SpaRouter { +impl SpaRouter { /// Set the path to the index file. /// /// `path` must be relative to `assets_dir` passed to [`SpaRouter::new`]. @@ -138,7 +138,7 @@ impl SpaRouter { /// let app = Router::new().merge(spa); /// # let _: Router = app; /// ``` - pub fn handle_error(self, f: F2) -> SpaRouter { + pub fn handle_error(self, f: F2) -> SpaRouter { SpaRouter { paths: self.paths, handle_error: f, @@ -147,7 +147,7 @@ impl SpaRouter { } } -impl From> for Router<(), B> +impl From> for Router where F: Clone + Send + Sync + 'static, HandleError, F, T>: Service, Error = Infallible>, @@ -155,8 +155,9 @@ where , F, T> as Service>>::Future: Send, B: HttpBody + Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { - fn from(spa: SpaRouter) -> Self { + fn from(spa: SpaRouter) -> Router { let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir)) .handle_error(spa.handle_error.clone()); @@ -195,7 +196,7 @@ where fn clone(&self) -> Self { Self { paths: self.paths.clone(), - handle_error: self.handle_error.clone(), + handle_error: self.handle_error, _marker: self._marker, } } @@ -264,13 +265,14 @@ mod tests { let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error); - Router::<_, Body>::new().merge(spa); + Router::<(), Body>::new().merge(spa); } #[allow(dead_code)] fn works_with_router_with_state() { - let _: Router = Router::with_state(String::new()) + let _: axum::RouterService = Router::new() .merge(SpaRouter::new("/assets", "test_files")) - .route("/", get(|_: axum::extract::State| async {})); + .route("/", get(|_: axum::extract::State| async {})) + .with_state(String::new()); } } diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index e9bae065..0dabf261 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -25,7 +25,7 @@ syn = { version = "1.0", features = [ ] } [dev-dependencies] -axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] } +axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers", "macros"] } axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 3f1a7db2..6a2b8ea4 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -181,7 +181,6 @@ use from_request::Trait::{FromRequest, FromRequestParts}; /// rejection type with `#[from_request(rejection(YourType))]`: /// /// ``` -/// use axum_macros::FromRequest; /// use axum::{ /// extract::{ /// rejection::{ExtensionRejection, StringRejection}, @@ -463,8 +462,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// As the error message says, handler function needs to be async. /// /// ``` -/// use axum::{routing::get, Router}; -/// use axum_macros::debug_handler; +/// use axum::{routing::get, Router, debug_handler}; /// /// #[tokio::main] /// async fn main() { @@ -493,8 +491,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// To work around that the request body type can be customized like so: /// /// ``` -/// use axum::{body::BoxBody, http::Request}; -/// # use axum_macros::debug_handler; +/// use axum::{body::BoxBody, http::Request, debug_handler}; /// /// #[debug_handler(body = BoxBody)] /// async fn handler(request: Request) {} @@ -506,8 +503,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// [`axum::extract::State`] argument: /// /// ``` -/// use axum::extract::State; -/// # use axum_macros::debug_handler; +/// use axum::{debug_handler, extract::State}; /// /// #[debug_handler] /// async fn handler( @@ -523,8 +519,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// customize the state type you can set it with `#[debug_handler(state = ...)]`: /// /// ``` -/// use axum::extract::{State, FromRef}; -/// # use axum_macros::debug_handler; +/// use axum::{debug_handler, extract::{State, FromRef}}; /// /// #[debug_handler(state = AppState)] /// async fn handler( @@ -579,8 +574,11 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream { /// # Example /// /// ``` -/// use axum_macros::FromRef; -/// use axum::{Router, routing::get, extract::State}; +/// use axum::{ +/// Router, +/// routing::get, +/// extract::{State, FromRef}, +/// }; /// /// # /// # type AuthToken = String; @@ -605,8 +603,10 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream { /// database_pool, /// }; /// -/// let app = Router::with_state(state).route("/", get(handler).post(other_handler)); -/// # let _: Router = app; +/// let app = Router::new() +/// .route("/", get(handler).post(other_handler)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` /// /// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html diff --git a/axum-macros/tests/from_ref/pass/basic.rs b/axum-macros/tests/from_ref/pass/basic.rs index 4a6631d3..055d63b8 100644 --- a/axum-macros/tests/from_ref/pass/basic.rs +++ b/axum-macros/tests/from_ref/pass/basic.rs @@ -1,5 +1,4 @@ -use axum_macros::FromRef; -use axum::{Router, routing::get, extract::State}; +use axum::{Router, routing::get, extract::{State, FromRef}}; // This will implement `FromRef` for each field in the struct. #[derive(Clone, FromRef)] @@ -15,5 +14,7 @@ fn main() { auth_token: Default::default(), }; - let _: Router = Router::with_state(state).route("/", get(handler)); + let _: axum::routing::RouterService = Router::new() + .route("/", get(handler)) + .with_state(state); } diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.rs b/axum-macros/tests/from_request/fail/parts_extracting_body.rs index 18fb312d..45a93777 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.rs +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.rs @@ -1,5 +1,4 @@ use axum::{extract::FromRequestParts, response::Response}; -use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor { diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index 9d92f40a..32341828 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -1,7 +1,7 @@ error[E0277]: the trait bound `String: FromRequestParts` is not satisfied - --> tests/from_request/fail/parts_extracting_body.rs:6:11 + --> tests/from_request/fail/parts_extracting_body.rs:5:11 | -6 | body: String, +5 | body: String, | ^^^^^^ the trait `FromRequestParts` is not implemented for `String` | = help: the following other types implement trait `FromRequestParts`: diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index c125902b..6e62c569 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -3,7 +3,6 @@ use axum::{ extract::{FromRequest, Json}, response::Response, }; -use axum_macros::FromRequest; use serde::Deserialize; #[derive(Deserialize, FromRequest)] diff --git a/axum-macros/tests/from_request/pass/container_parts.rs b/axum-macros/tests/from_request/pass/container_parts.rs index c3dabe54..dedc1719 100644 --- a/axum-macros/tests/from_request/pass/container_parts.rs +++ b/axum-macros/tests/from_request/pass/container_parts.rs @@ -2,7 +2,6 @@ use axum::{ extract::{FromRequestParts, Extension}, response::Response, }; -use axum_macros::FromRequestParts; #[derive(Clone, FromRequestParts)] #[from_request(via(Extension))] diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index 092989d1..e042477b 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -4,7 +4,6 @@ use axum::{ response::Response, headers::{self, UserAgent}, }; -use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/named_parts.rs b/axum-macros/tests/from_request/pass/named_parts.rs index b997a0dd..27dce64f 100644 --- a/axum-macros/tests/from_request/pass/named_parts.rs +++ b/axum-macros/tests/from_request/pass/named_parts.rs @@ -3,7 +3,6 @@ use axum::{ headers::{self, UserAgent}, response::Response, }; -use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index 23da2ac6..41cc3615 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -7,7 +7,6 @@ use axum::{ }, headers::{self, UserAgent}, }; -use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/named_via_parts.rs b/axum-macros/tests/from_request/pass/named_via_parts.rs index bdf1ac6e..9a389e54 100644 --- a/axum-macros/tests/from_request/pass/named_via_parts.rs +++ b/axum-macros/tests/from_request/pass/named_via_parts.rs @@ -6,7 +6,6 @@ use axum::{ }, headers::{self, UserAgent}, }; -use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 7167ffe0..0147c9a8 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -6,7 +6,6 @@ use axum::{ routing::get, Extension, Router, }; -use axum_macros::FromRequest; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); diff --git a/axum-macros/tests/from_request/pass/override_rejection_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_parts.rs index 8e2271c3..8ef9cb22 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_parts.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_parts.rs @@ -6,7 +6,6 @@ use axum::{ routing::get, Extension, Router, }; -use axum_macros::FromRequestParts; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); diff --git a/axum-macros/tests/from_request/pass/state_enum_via.rs b/axum-macros/tests/from_request/pass/state_enum_via.rs index 99af401c..8da81901 100644 --- a/axum-macros/tests/from_request/pass/state_enum_via.rs +++ b/axum-macros/tests/from_request/pass/state_enum_via.rs @@ -6,9 +6,10 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = Router::with_state(AppState::default()) + let _: axum::routing::RouterService = Router::new() .route("/a", get(|_: AppState| async {})) - .route("/b", get(|_: InnerState| async {})); + .route("/b", get(|_: InnerState| async {})) + .with_state(AppState::default()); } #[derive(Clone, FromRequest)] diff --git a/axum-macros/tests/from_request/pass/state_enum_via_parts.rs b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs index 9700ac94..3d5b5b0b 100644 --- a/axum-macros/tests/from_request/pass/state_enum_via_parts.rs +++ b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs @@ -6,10 +6,11 @@ use axum::{ use axum_macros::FromRequestParts; fn main() { - let _: Router = Router::with_state(AppState::default()) + let _: axum::routing::RouterService = Router::new() .route("/a", get(|_: AppState| async {})) .route("/b", get(|_: InnerState| async {})) - .route("/c", get(|_: AppState, _: InnerState| async {})); + .route("/c", get(|_: AppState, _: InnerState| async {})) + .with_state(AppState::default()); } #[derive(Clone, FromRequestParts)] diff --git a/axum-macros/tests/from_request/pass/state_explicit.rs b/axum-macros/tests/from_request/pass/state_explicit.rs index 5a608eab..bea2958d 100644 --- a/axum-macros/tests/from_request/pass/state_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_explicit.rs @@ -6,8 +6,9 @@ use axum::{ }; fn main() { - let _: Router = Router::with_state(AppState::default()) - .route("/b", get(|_: Extractor| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/b", get(|_: Extractor| async {})) + .with_state(AppState::default()); } #[derive(FromRequest)] diff --git a/axum-macros/tests/from_request/pass/state_explicit_parts.rs b/axum-macros/tests/from_request/pass/state_explicit_parts.rs index 2aeb6e3a..a4865fc8 100644 --- a/axum-macros/tests/from_request/pass/state_explicit_parts.rs +++ b/axum-macros/tests/from_request/pass/state_explicit_parts.rs @@ -7,8 +7,9 @@ use axum::{ use std::collections::HashMap; fn main() { - let _: Router = Router::with_state(AppState::default()) - .route("/b", get(|_: Extractor| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/b", get(|_: Extractor| async {})) + .with_state(AppState::default()); } #[derive(FromRequestParts)] diff --git a/axum-macros/tests/from_request/pass/state_field_explicit.rs b/axum-macros/tests/from_request/pass/state_field_explicit.rs index 363efab8..1caccf46 100644 --- a/axum-macros/tests/from_request/pass/state_field_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_field_explicit.rs @@ -6,8 +6,9 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = Router::with_state(AppState::default()) - .route("/", get(|_: Extractor| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/", get(|_: Extractor| async {})) + .with_state(AppState::default()); } #[derive(FromRequest)] diff --git a/axum-macros/tests/from_request/pass/state_field_infer.rs b/axum-macros/tests/from_request/pass/state_field_infer.rs index 03330578..08884dcb 100644 --- a/axum-macros/tests/from_request/pass/state_field_infer.rs +++ b/axum-macros/tests/from_request/pass/state_field_infer.rs @@ -6,8 +6,9 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = Router::with_state(AppState::default()) - .route("/", get(|_: Extractor| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/", get(|_: Extractor| async {})) + .with_state(AppState::default()); } #[derive(FromRequest)] diff --git a/axum-macros/tests/from_request/pass/state_via.rs b/axum-macros/tests/from_request/pass/state_via.rs index b7196a39..590ec535 100644 --- a/axum-macros/tests/from_request/pass/state_via.rs +++ b/axum-macros/tests/from_request/pass/state_via.rs @@ -6,9 +6,10 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = Router::with_state(AppState::default()) + let _: axum::routing::RouterService = Router::new() .route("/b", get(|_: (), _: AppState| async {})) - .route("/c", get(|_: (), _: InnerState| async {})); + .route("/c", get(|_: (), _: InnerState| async {})) + .with_state(AppState::default()); } #[derive(Clone, Default, FromRequest)] diff --git a/axum-macros/tests/from_request/pass/state_via_infer.rs b/axum-macros/tests/from_request/pass/state_via_infer.rs index 75b170a6..50b4a1bc 100644 --- a/axum-macros/tests/from_request/pass/state_via_infer.rs +++ b/axum-macros/tests/from_request/pass/state_via_infer.rs @@ -6,8 +6,9 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = Router::with_state(AppState::default()) - .route("/b", get(|_: AppState| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/b", get(|_: AppState| async {})) + .with_state(AppState::default()); } // if we're extract "via" `State` and not specifying state diff --git a/axum-macros/tests/from_request/pass/state_via_parts.rs b/axum-macros/tests/from_request/pass/state_via_parts.rs index b747f474..2817e832 100644 --- a/axum-macros/tests/from_request/pass/state_via_parts.rs +++ b/axum-macros/tests/from_request/pass/state_via_parts.rs @@ -6,10 +6,11 @@ use axum::{ use axum_macros::FromRequestParts; fn main() { - let _: Router = Router::with_state(AppState::default()) + let _: axum::routing::RouterService = Router::new() .route("/a", get(|_: AppState, _: InnerState, _: String| async {})) .route("/b", get(|_: AppState, _: String| async {})) - .route("/c", get(|_: InnerState, _: String| async {})); + .route("/c", get(|_: InnerState, _: String| async {})) + .with_state(AppState::default()); } #[derive(Clone, Default, FromRequestParts)] diff --git a/axum-macros/tests/from_request/pass/state_with_rejection.rs b/axum-macros/tests/from_request/pass/state_with_rejection.rs index 82ecfe3b..aef3d9c7 100644 --- a/axum-macros/tests/from_request/pass/state_with_rejection.rs +++ b/axum-macros/tests/from_request/pass/state_with_rejection.rs @@ -8,8 +8,9 @@ use axum::{ use axum_macros::FromRequest; fn main() { - let _: Router = - Router::with_state(AppState::default()).route("/a", get(|_: Extractor| async {})); + let _: axum::routing::RouterService = Router::new() + .route("/a", get(|_: Extractor| async {})) + .with_state(AppState::default()); } #[derive(Clone, Default, FromRequest)] diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 55314de5..5b7e96e8 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **breaking:** `Router::with_state` is no longer a constructor. It is instead + used to convert the router into a `RouterService` ([#1532]) + + This nested router on 0.6.0-rc.4 + + ```rust + Router::with_state(state).route(...); + ``` + + Becomes this in 0.6.0-rc.5 + + ```rust + Router::new().route(...).with_state(state); + ``` + +- **breaking:**: `Router::nest` and `Router::merge` now only supports nesting + routers that use the same state type as the router they're being merged into. + Use `FromRef` for substates ([#1532]) + +- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) - **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521]) - **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) - **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do @@ -15,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1539]: https://github.com/tokio-rs/axum/pull/1539 [#1521]: https://github.com/tokio-rs/axum/pull/1521 +[#1529]: https://github.com/tokio-rs/axum/pull/1529 +[#1532]: https://github.com/tokio-rs/axum/pull/1532 # 0.6.0-rc.4 (9. November, 2022) diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index 13295a3c..0ff269f0 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -1,7 +1,7 @@ use axum::{ extract::State, routing::{get, post}, - Extension, Json, Router, Server, + Extension, Json, Router, RouterService, Server, }; use hyper::server::conn::AddrIncoming; use serde::{Deserialize, Serialize}; @@ -17,9 +17,13 @@ fn main() { ensure_rewrk_is_installed(); } - benchmark("minimal").run(Router::new); + benchmark("minimal").run(|| Router::new().into_service()); - benchmark("basic").run(|| Router::new().route("/", get(|| async { "Hello, World!" }))); + benchmark("basic").run(|| { + Router::new() + .route("/", get(|| async { "Hello, World!" })) + .into_service() + }); benchmark("routing").path("/foo/bar/baz").run(|| { let mut app = Router::new(); @@ -30,26 +34,32 @@ fn main() { } } } - app.route("/foo/bar/baz", get(|| async {})) + app.route("/foo/bar/baz", get(|| async {})).into_service() }); benchmark("receive-json") .method("post") .headers(&[("content-type", "application/json")]) .body(r#"{"n": 123, "s": "hi there", "b": false}"#) - .run(|| Router::new().route("/", post(|_: Json| async {}))); + .run(|| { + Router::new() + .route("/", post(|_: Json| async {})) + .into_service() + }); benchmark("send-json").run(|| { - Router::new().route( - "/", - get(|| async { - Json(Payload { - n: 123, - s: "hi there".to_owned(), - b: false, - }) - }), - ) + Router::new() + .route( + "/", + get(|| async { + Json(Payload { + n: 123, + s: "hi there".to_owned(), + b: false, + }) + }), + ) + .into_service() }); let state = AppState { @@ -65,10 +75,14 @@ fn main() { Router::new() .route("/", get(|_: Extension| async {})) .layer(Extension(state.clone())) + .into_service() }); - benchmark("state") - .run(|| Router::with_state(state.clone()).route("/", get(|_: State| async {}))); + benchmark("state").run(|| { + Router::new() + .route("/", get(|_: State| async {})) + .with_state(state.clone()) + }); } #[derive(Clone)] @@ -117,10 +131,9 @@ impl BenchmarkBuilder { config_method!(headers, &'static [(&'static str, &'static str)]); config_method!(body, &'static str); - fn run(self, f: F) + fn run(self, f: F) where - F: FnOnce() -> Router, - S: Clone + Send + Sync + 'static, + F: FnOnce() -> RouterService, { // support only running some benchmarks with // ``` diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs new file mode 100644 index 00000000..3612d211 --- /dev/null +++ b/axum/src/boxed.rs @@ -0,0 +1,173 @@ +use std::{convert::Infallible, fmt}; + +use crate::{body::HttpBody, handler::Handler, routing::Route, Router}; + +pub(crate) struct BoxedIntoRoute(Box>); + +impl BoxedIntoRoute +where + S: Clone + Send + Sync + 'static, + B: Send + 'static, +{ + pub(crate) fn from_handler(handler: H) -> Self + where + H: Handler, + T: 'static, + { + Self(Box::new(MakeErasedHandler { + handler, + into_route: |handler, state| Route::new(Handler::with_state(handler, state)), + })) + } + + pub(crate) fn from_router(router: Router) -> Self + where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, + { + Self(Box::new(MakeErasedRouter { + router, + into_route: |router, state| Route::new(router.with_state(state)), + })) + } +} + +impl BoxedIntoRoute { + pub(crate) fn map(self, f: F) -> BoxedIntoRoute + where + S: 'static, + B: 'static, + E: 'static, + F: FnOnce(Route) -> Route + Clone + Send + 'static, + B2: 'static, + E2: 'static, + { + BoxedIntoRoute(Box::new(Map { + inner: self.0, + layer: Box::new(f), + })) + } + + pub(crate) fn into_route(self, state: S) -> Route { + self.0.into_route(state) + } +} + +impl Clone for BoxedIntoRoute { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +impl fmt::Debug for BoxedIntoRoute { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("BoxedIntoRoute").finish() + } +} + +pub(crate) trait ErasedIntoRoute: Send { + fn clone_box(&self) -> Box>; + + fn into_route(self: Box, state: S) -> Route; +} + +pub(crate) struct MakeErasedHandler { + pub(crate) handler: H, + pub(crate) into_route: fn(H, S) -> Route, +} + +impl ErasedIntoRoute for MakeErasedHandler +where + H: Clone + Send + 'static, + S: 'static, + B: 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn into_route(self: Box, state: S) -> Route { + (self.into_route)(self.handler, state) + } +} + +impl Clone for MakeErasedHandler +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + into_route: self.into_route, + } + } +} + +pub(crate) struct MakeErasedRouter { + pub(crate) router: Router, + pub(crate) into_route: fn(Router, S) -> Route, +} + +impl ErasedIntoRoute for MakeErasedRouter +where + S: Clone + Send + 'static, + B: 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn into_route(self: Box, state: S) -> Route { + (self.into_route)(self.router, state) + } +} + +impl Clone for MakeErasedRouter +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + router: self.router.clone(), + into_route: self.into_route, + } + } +} + +pub(crate) struct Map { + pub(crate) inner: Box>, + pub(crate) layer: Box>, +} + +impl ErasedIntoRoute for Map +where + S: 'static, + B: 'static, + E: 'static, + B2: 'static, + E2: 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(Self { + inner: self.inner.clone_box(), + layer: self.layer.clone_box(), + }) + } + + fn into_route(self: Box, state: S) -> Route { + (self.layer)(self.inner.into_route(state)) + } +} + +pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { + fn clone_box(&self) -> Box>; +} + +impl LayerFn for F +where + F: FnOnce(Route) -> Route + Clone + Send + 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } +} diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index c3579c13..603fac78 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -462,10 +462,11 @@ async fn handler(_: State) {} let state = AppState {}; -let app = Router::with_state(state.clone()) +let app = Router::new() .route("/", get(handler)) - .layer(MyLayer { state }); -# let _: Router<_> = app; + .layer(MyLayer { state: state.clone() }) + .with_state(state); +# let _: axum::routing::RouterService = app; ``` # Passing state from middleware to handlers diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 05ee750c..86165361 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -2,6 +2,10 @@ Convert this router into a [`MakeService`], that will store `C`'s associated `ConnectInfo` in a request extension such that [`ConnectInfo`] can extract it. +This is a convenience method for routers that don't have any state (i.e. the +state type is `()`). Use [`RouterService::into_make_service_with_connect_info`] +otherwise. + This enables extracting things like the client's remote address. Extracting [`std::net::SocketAddr`] is supported out of the box: diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 6ea05478..96dbd334 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -147,6 +147,43 @@ let app = Router::new() Here requests like `GET /api/not-found` will go to `api_fallback`. +# Nesting a router with a different state type + +By default `nest` requires a `Router` with the same state type as the outer +`Router`. If you need to nest a `Router` with a different state type you can +use [`Router::with_state`] and [`Router::nest_service`]: + +```rust +use axum::{ + Router, + routing::get, + extract::State, +}; + +#[derive(Clone)] +struct InnerState {} + +#[derive(Clone)] +struct OuterState {} + +async fn inner_handler(state: State) {} + +let inner_router = Router::new() + .route("/bar", get(inner_handler)) + .with_state(InnerState {}); + +async fn outer_handler(state: State) {} + +let app = Router::new() + .route("/", get(outer_handler)) + .nest_service("/foo", inner_router) + .with_state(OuterState {}); +# let _: axum::routing::RouterService = app; +``` + +Note that the inner router will still inherit the fallback from the outer +router. + # Panics - If the route overlaps with another route. See [`Router::route`] diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 7e7b68b5..966d72a0 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -31,7 +31,10 @@ use std::{ /// let state = AppState {}; /// /// // create a `Router` that holds our state -/// let app = Router::with_state(state).route("/", get(handler)); +/// let app = Router::new() +/// .route("/", get(handler)) +/// // provide the state so the router can access it +/// .with_state(state); /// /// async fn handler( /// // access the state via the `State` extractor @@ -40,7 +43,7 @@ use std::{ /// ) { /// // use `state`... /// } -/// # let _: Router = app; +/// # let _: axum::routing::RouterService = app; /// ``` /// /// # With `MethodRouter` @@ -119,9 +122,10 @@ use std::{ /// api_state: ApiState {}, /// }; /// -/// let app = Router::with_state(state) +/// let app = Router::new() /// .route("/", get(handler)) -/// .route("/api/users", get(api_users)); +/// .route("/api/users", get(api_users)) +/// .with_state(state); /// /// async fn api_users( /// // access the api specific state @@ -134,9 +138,11 @@ use std::{ /// State(state): State, /// ) { /// } -/// # let _: Router = app; +/// # let _: axum::routing::RouterService = app; /// ``` /// +/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`. +/// /// # For library authors /// /// If you're writing a library that has an extractor that needs state, this is the recommended way diff --git a/axum/src/handler/boxed.rs b/axum/src/handler/boxed.rs deleted file mode 100644 index bfdd0d0f..00000000 --- a/axum/src/handler/boxed.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::convert::Infallible; - -use super::Handler; -use crate::routing::Route; - -pub(crate) struct BoxedHandler(Box>); - -impl BoxedHandler -where - S: Clone + Send + Sync + 'static, - B: Send + 'static, -{ - pub(crate) fn new(handler: H) -> Self - where - H: Handler, - T: 'static, - { - Self(Box::new(MakeErasedHandler { - handler, - into_route: |handler, state| Route::new(Handler::with_state(handler, state)), - })) - } -} - -impl BoxedHandler { - pub(crate) fn map(self, f: F) -> BoxedHandler - where - S: 'static, - B: 'static, - E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, - E2: 'static, - { - BoxedHandler(Box::new(Map { - handler: self.0, - layer: Box::new(f), - })) - } - - pub(crate) fn into_route(self, state: S) -> Route { - self.0.into_route(state) - } -} - -impl Clone for BoxedHandler { - fn clone(&self) -> Self { - Self(self.0.clone_box()) - } -} - -trait ErasedHandler: Send { - fn clone_box(&self) -> Box>; - - fn into_route(self: Box, state: S) -> Route; -} - -struct MakeErasedHandler { - handler: H, - into_route: fn(H, S) -> Route, -} - -impl ErasedHandler for MakeErasedHandler -where - H: Clone + Send + 'static, - S: 'static, - B: 'static, -{ - fn clone_box(&self) -> Box> { - Box::new(self.clone()) - } - - fn into_route(self: Box, state: S) -> Route { - (self.into_route)(self.handler, state) - } -} - -impl Clone for MakeErasedHandler { - fn clone(&self) -> Self { - Self { - handler: self.handler.clone(), - into_route: self.into_route, - } - } -} - -struct Map { - handler: Box>, - layer: Box>, -} - -impl ErasedHandler for Map -where - S: 'static, - B: 'static, - E: 'static, - B2: 'static, - E2: 'static, -{ - fn clone_box(&self) -> Box> { - Box::new(Self { - handler: self.handler.clone_box(), - layer: self.layer.clone_box(), - }) - } - - fn into_route(self: Box, state: S) -> Route { - (self.layer)(self.handler.into_route(state)) - } -} - -trait LayerFn: FnOnce(Route) -> Route + Send { - fn clone_box(&self) -> Box>; -} - -impl LayerFn for F -where - F: FnOnce(Route) -> Route + Clone + Send + 'static, -{ - fn clone_box(&self) -> Box> { - Box::new(self.clone()) - } -} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index d7427dc1..fe76f9bd 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -49,11 +49,9 @@ use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; -mod boxed; pub mod future; mod service; -pub(crate) use self::boxed::BoxedHandler; pub use self::service::HandlerService; /// Trait for async functions that can be used to handle requests. diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 837701e4..60c33834 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -188,8 +188,9 @@ //! //! let shared_state = Arc::new(AppState { /* ... */ }); //! -//! let app = Router::with_state(shared_state) -//! .route("/", get(handler)); +//! let app = Router::new() +//! .route("/", get(handler)) +//! .with_state(shared_state); //! //! async fn handler( //! State(state): State>, @@ -434,6 +435,7 @@ #[macro_use] pub(crate) mod macros; +mod boxed; mod extension; #[cfg(feature = "form")] mod form; diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 3324ad87..e9a525ee 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -133,10 +133,11 @@ pub fn from_fn(f: F) -> FromFnLayer { /// /// let state = AppState { /* ... */ }; /// -/// let app = Router::with_state(state.clone()) +/// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) -/// .route_layer(middleware::from_fn_with_state(state, my_middleware)); -/// # let app: Router<_> = app; +/// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { FromFnLayer { diff --git a/axum/src/middleware/map_request.rs b/axum/src/middleware/map_request.rs index f574681d..eeac8161 100644 --- a/axum/src/middleware/map_request.rs +++ b/axum/src/middleware/map_request.rs @@ -148,10 +148,11 @@ pub fn map_request(f: F) -> MapRequestLayer { /// /// let state = AppState { /* ... */ }; /// -/// let app = Router::with_state(state.clone()) +/// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) -/// .route_layer(map_request_with_state(state, my_middleware)); -/// # let app: Router<_> = app; +/// .route_layer(map_request_with_state(state.clone(), my_middleware)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` pub fn map_request_with_state(state: S, f: F) -> MapRequestLayer { MapRequestLayer { diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs index a8b37d2e..3be34ab6 100644 --- a/axum/src/middleware/map_response.rs +++ b/axum/src/middleware/map_response.rs @@ -132,10 +132,11 @@ pub fn map_response(f: F) -> MapResponseLayer { /// /// let state = AppState { /* ... */ }; /// -/// let app = Router::with_state(state.clone()) +/// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) -/// .route_layer(map_response_with_state(state, my_middleware)); -/// # let app: Router<_> = app; +/// .route_layer(map_response_with_state(state.clone(), my_middleware)) +/// .with_state(state); +/// # let _: axum::routing::RouterService = app; /// ``` pub fn map_response_with_state(state: S, f: F) -> MapResponseLayer { MapResponseLayer { diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 3bfcead3..393885c4 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -98,7 +98,7 @@ mod tests { } } - Router::<_, Body>::new() + Router::<(), Body>::new() .route("/", get(impl_trait_ok)) .route("/", get(impl_trait_err)) .route("/", get(impl_trait_both)) @@ -208,7 +208,7 @@ mod tests { ) } - Router::<_, Body>::new() + Router::<(), Body>::new() .route("/", get(status)) .route("/", get(status_headermap)) .route("/", get(status_header_array)) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 8672b1b2..4b5d5b3f 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -5,12 +5,12 @@ use super::{FallbackRoute, IntoMakeService}; use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, Bytes, HttpBody}, + boxed::BoxedIntoRoute, error_handling::{HandleError, HandleErrorLayer}, - handler::{BoxedHandler, Handler}, + handler::Handler, http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, - util::try_downcast, }; use axum_core::response::IntoResponse; use bytes::BytesMut; @@ -606,7 +606,7 @@ where { self.on_endpoint( filter, - MethodEndpoint::BoxedHandler(BoxedHandler::new(handler)), + MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)), ) } @@ -626,7 +626,7 @@ where T: 'static, S: Send + Sync + 'static, { - self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler)); + self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } } @@ -749,46 +749,6 @@ where } } - pub(crate) fn map_state(self, state: &S) -> MethodRouter - where - E: 'static, - S: 'static, - S2: 'static, - { - MethodRouter { - get: self.get.map_state(state), - head: self.head.map_state(state), - delete: self.delete.map_state(state), - options: self.options.map_state(state), - patch: self.patch.map_state(state), - post: self.post.map_state(state), - put: self.put.map_state(state), - trace: self.trace.map_state(state), - fallback: self.fallback.map_state(state), - allow_header: self.allow_header, - } - } - - pub(crate) fn downcast_state(self) -> Option> - where - E: 'static, - S: 'static, - S2: 'static, - { - Some(MethodRouter { - get: self.get.downcast_state()?, - head: self.head.downcast_state()?, - delete: self.delete.downcast_state()?, - options: self.options.downcast_state()?, - patch: self.patch.downcast_state()?, - post: self.post.downcast_state()?, - put: self.put.downcast_state()?, - trace: self.trace.downcast_state()?, - fallback: self.fallback.downcast_state()?, - allow_header: self.allow_header, - }) - } - /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -964,17 +924,14 @@ where ) -> MethodRouter where L: Layer> + Clone + Send + 'static, - L::Service: Service, Error = NewError> + Clone + Send + 'static, + L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, + >>::Error: Into + 'static, >>::Future: Send + 'static, E: 'static, S: 'static, { - let layer_fn = move |svc| { - let svc = layer.layer(svc); - let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc); - Route::new(svc) - }; + let layer_fn = move |route: Route| route.layer(layer.clone()); MethodRouter { get: self.get.map(layer_fn.clone()), @@ -1182,7 +1139,7 @@ where enum MethodEndpoint { None, Route(Route), - BoxedHandler(BoxedHandler), + BoxedHandler(BoxedIntoRoute), } impl MethodEndpoint @@ -1213,32 +1170,6 @@ where } } - fn map_state(self, state: &S) -> MethodEndpoint { - match self { - Self::None => MethodEndpoint::None, - Self::Route(route) => MethodEndpoint::Route(route), - Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())), - } - } - - fn downcast_state(self) -> Option> - where - S: 'static, - B: 'static, - E: 'static, - S2: 'static, - { - match self { - Self::None => Some(MethodEndpoint::None), - Self::Route(route) => Some(MethodEndpoint::Route(route)), - Self::BoxedHandler(handler) => { - try_downcast::, BoxedHandler>(handler) - .map(MethodEndpoint::BoxedHandler) - .ok() - } - } - } - fn into_route(self, state: &S) -> Option> { match self { Self::None => None, diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 88a23127..5eaf1c0d 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,27 +1,19 @@ //! Routing between [`Service`]s and handlers. -use self::not_found::NotFound; +use self::{not_found::NotFound, strip_prefix::StripPrefix}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ body::{Body, HttpBody}, - handler::{BoxedHandler, Handler}, + boxed::BoxedIntoRoute, + handler::Handler, util::try_downcast, }; use axum_core::response::{IntoResponse, Response}; use http::Request; use matchit::MatchError; -use std::{ - any::{type_name, TypeId}, - collections::HashMap, - convert::Infallible, - fmt, - sync::Arc, -}; -use tower::{ - util::{BoxCloneService, MapResponseLayer, Oneshot}, - ServiceBuilder, -}; +use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc}; +use tower::util::{BoxCloneService, Oneshot}; use tower_layer::Layer; use tower_service::Service; @@ -68,19 +60,14 @@ impl RouteId { /// The router type for composing handlers and services. pub struct Router { - state: Option, routes: HashMap>, node: Arc, fallback: Fallback, } -impl Clone for Router -where - S: Clone, -{ +impl Clone for Router { fn clone(&self) -> Self { Self { - state: self.state.clone(), routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), @@ -91,10 +78,10 @@ where impl Default for Router where B: HttpBody + Send + 'static, - S: Default + Clone + Send + Sync + 'static, + S: Clone + Send + Sync + 'static, { fn default() -> Self { - Self::with_state(S::default()) + Self::new() } } @@ -104,7 +91,6 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") - .field("state", &self.state) .field("routes", &self.routes) .field("node", &self.node) .field("fallback", &self.fallback) @@ -115,71 +101,19 @@ where pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; -impl Router<(), B> -where - B: HttpBody + Send + 'static, -{ - /// Create a new `Router`. - /// - /// Unless you add additional routes this will respond with `404 Not Found` to - /// all requests. - pub fn new() -> Self { - Self::with_state(()) - } -} +impl Router<(), B> where B: HttpBody + Send + 'static {} impl Router where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { - /// Create a new `Router` with the given state. - /// - /// See [`State`](crate::extract::State) for more details about accessing state. + /// Create a new `Router`. /// /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. - pub fn with_state(state: S) -> Self { + pub fn new() -> Self { Self { - state: Some(state), - routes: Default::default(), - node: Default::default(), - fallback: Fallback::Default(Route::new(NotFound)), - } - } - - /// Create a new `Router` that inherits its state from another `Router` that it is merged into - /// or nested under. - /// - /// # Example - /// - /// ``` - /// use axum::{Router, routing::get, extract::State}; - /// - /// #[derive(Clone)] - /// struct AppState {} - /// - /// // A router that will be nested under the `app` router. - /// // - /// // By using `inherit_state` we'll reuse the state from the `app` router. - /// let nested_router = Router::inherit_state() - /// .route("/bar", get(|state: State| async {})); - /// - /// // A router that will be merged into the `app` router. - /// let merged_router = Router::inherit_state() - /// .route("/baz", get(|state: State| async {})); - /// - /// let app = Router::with_state(AppState {}) - /// .route("/", get(|state: State| async {})) - /// .nest("/foo", nested_router) - /// .merge(merged_router); - /// - /// // `app` now has routes for `/`, `/foo/bar`, and `/baz` that all use the same state. - /// # let _: Router = app; - /// ``` - pub fn inherit_state() -> Self { - Self { - state: None, routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), @@ -228,18 +162,12 @@ where } #[doc = include_str!("../docs/routing/route_service.md")] - pub fn route_service(mut self, path: &str, service: T) -> Self + pub fn route_service(self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - panic!("Paths must start with a `/`"); - } - let service = match try_downcast::, _>(service) { Ok(_) => { panic!( @@ -250,11 +178,20 @@ where Err(svc) => svc, }; + self.route_endpoint(path, Endpoint::Route(Route::new(service))) + } + + #[track_caller] + fn route_endpoint(mut self, path: &str, endpoint: Endpoint) -> Self { + if path.is_empty() { + panic!("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + panic!("Paths must start with a `/`"); + } + let id = RouteId::next(); - let endpoint = Endpoint::Route(Route::new(service)); self.set_node(path, id); self.routes.insert(id, endpoint); - self } @@ -270,29 +207,27 @@ where #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] - pub fn nest(self, path: &str, mut router: Router) -> Self - where - S2: Clone + Send + Sync + 'static, - { - if router.state.is_none() { - let s = self.state.clone(); - router.state = match try_downcast::, Option>(s) { - Ok(state) => state, - Err(_) => panic!( - "can't nest a `Router` that wants to inherit state of type `{}` \ - into a `Router` with a state type of `{}`", - type_name::(), - type_name::(), - ), - }; - } - - self.nest_service(path, router.into_service()) + pub fn nest(self, path: &str, router: Router) -> Self { + self.nest_endpoint(path, RouterOrService::<_, _, NotFound>::Router(router)) } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. #[track_caller] - pub fn nest_service(mut self, mut path: &str, svc: T) -> Self + pub fn nest_service(self, path: &str, svc: T) -> Self + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + { + self.nest_endpoint(path, RouterOrService::Service(svc)) + } + + #[track_caller] + fn nest_endpoint( + mut self, + mut path: &str, + router_or_service: RouterOrService, + ) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, @@ -315,16 +250,27 @@ where format!("{path}/*{NEST_TAIL_PARAM}") }; - let svc = strip_prefix::StripPrefix::new(svc, prefix); - self = self.route_service(&path, svc.clone()); + let endpoint = match router_or_service { + RouterOrService::Router(router) => { + let prefix = prefix.to_owned(); + let boxed = BoxedIntoRoute::from_router(router) + .map(move |route| Route::new(StripPrefix::new(route, &prefix))); + Endpoint::NestedRouter(boxed) + } + RouterOrService::Service(svc) => { + Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))) + } + }; + + self = self.route_endpoint(&path, endpoint.clone()); // `/*rest` is not matched by `/` so we need to also register a router at the // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself // wouldn't match, which it should - self = self.route_service(prefix, svc.clone()); + self = self.route_endpoint(prefix, endpoint.clone()); if !prefix.ends_with('/') { // same goes for `/foo/`, that should also match - self = self.route_service(&format!("{prefix}/"), svc); + self = self.route_endpoint(&format!("{prefix}/"), endpoint); } self @@ -332,66 +278,27 @@ where #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: R) -> Self + pub fn merge(mut self, other: R) -> Self where - R: Into>, - S2: Clone + Send + Sync + 'static, + R: Into>, { let Router { - state, routes, node, fallback, } = other.into(); - let cast_method_router_closure_slot; - let (fallback, cast_method_router) = match state { - // other has its state set - Some(state) => { - let fallback = fallback.map_state(&state); - cast_method_router_closure_slot = move |r: MethodRouter<_, _>| r.map_state(&state); - let cast_method_router = &cast_method_router_closure_slot - as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>; - - (fallback, cast_method_router) - } - // other wants to inherit its state - None => { - if TypeId::of::() != TypeId::of::() { - panic!( - "can't merge a `Router` that wants to inherit state of type `{}` \ - into a `Router` with a state type of `{}`", - type_name::(), - type_name::(), - ); - } - - // With the branch above not taken, we know we can cast S2 to S - let fallback = fallback.downcast_state::().unwrap(); - - fn cast_method_router(r: MethodRouter) -> MethodRouter - where - B: Send + 'static, - S: 'static, - S2: Clone + 'static, - { - r.downcast_state().unwrap() - } - - (fallback, &cast_method_router as _) - } - }; - for (id, route) in routes { let path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); self = match route { - Endpoint::MethodRouter(method_router) => { - self.route(path, cast_method_router(method_router)) - } + Endpoint::MethodRouter(method_router) => self.route(path, method_router), Endpoint::Route(route) => self.route_service(path, route), + Endpoint::NestedRouter(router) => { + self.route_endpoint(path, Endpoint::NestedRouter(router)) + } }; } @@ -412,30 +319,18 @@ where >>::Error: Into + 'static, >>::Future: Send + 'static, { - let layer = ServiceBuilder::new() - .map_err(Into::into) - .layer(MapResponseLayer::new(IntoResponse::into_response)) - .layer(layer) - .into_inner(); - let routes = self .routes .into_iter() - .map(|(id, route)| { - let route = match route { - Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.layer(layer.clone())) - } - Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), - }; + .map(|(id, endpoint)| { + let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); - let fallback = self.fallback.map(move |svc| Route::new(layer.layer(svc))); + let fallback = self.fallback.map(|route| route.layer(layer)); Router { - state: self.state, routes, node: self.node, fallback, @@ -459,28 +354,16 @@ where ); } - let layer = ServiceBuilder::new() - .map_err(Into::into) - .layer(MapResponseLayer::new(IntoResponse::into_response)) - .layer(layer) - .into_inner(); - let routes = self .routes .into_iter() - .map(|(id, route)| { - let route = match route { - Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.layer(layer.clone())) - } - Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))), - }; + .map(|(id, endpoint)| { + let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); Router { - state: self.state, routes, node: self.node, fallback: self.fallback, @@ -493,7 +376,7 @@ where H: Handler, T: 'static, { - self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler)); + self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } @@ -510,14 +393,26 @@ where self } + /// Convert this router into a [`RouterService`] by providing the state. + /// + /// Once this method has been called you cannot add more routes. So it must be called as last. + pub fn with_state(self, state: S) -> RouterService { + RouterService::new(self, state) + } +} + +impl Router<(), B> +where + B: HttpBody + Send + 'static, +{ /// Convert this router into a [`RouterService`]. /// - /// # Panics + /// This is a convenience method for routers that don't have any state (i.e. the state type is + /// `()`). Use [`Router::with_state`] otherwise. /// - /// Panics if the router was constructed with [`Router::inherit_state`]. - #[track_caller] + /// Once this method has been called you cannot add more routes. So it must be called as last. pub fn into_service(self) -> RouterService { - RouterService::new(self) + RouterService::new(self, ()) } /// Convert this router into a [`MakeService`], that is a [`Service`] whose @@ -542,8 +437,10 @@ where /// # }; /// ``` /// + /// This is a convenience method for routers that don't have any state (i.e. the state type is + /// `()`). Use [`RouterService::into_make_service`] otherwise. + /// /// [`MakeService`]: tower::make::MakeService - #[track_caller] pub fn into_make_service(self) -> IntoMakeService> { IntoMakeService::new(self.into_service()) } @@ -601,39 +498,13 @@ impl fmt::Debug for Node { enum Fallback { Default(Route), Service(Route), - BoxedHandler(BoxedHandler), + BoxedHandler(BoxedIntoRoute), } impl Fallback where S: Clone, { - fn map_state(self, state: &S) -> Fallback { - match self { - Self::Default(route) => Fallback::Default(route), - Self::Service(route) => Fallback::Service(route), - Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state.clone())), - } - } - - fn downcast_state(self) -> Option> - where - S: 'static, - B: 'static, - E: 'static, - S2: 'static, - { - match self { - Self::Default(route) => Some(Fallback::Default(route)), - Self::Service(route) => Some(Fallback::Service(route)), - Self::BoxedHandler(handler) => { - try_downcast::, BoxedHandler>(handler) - .map(Fallback::BoxedHandler) - .ok() - } - } - } - fn merge(self, other: Self) -> Option { match (self, other) { (Self::Default(_), pick @ Self::Default(_)) => Some(pick), @@ -729,6 +600,41 @@ impl FallbackRoute { enum Endpoint { MethodRouter(MethodRouter), Route(Route), + NestedRouter(BoxedIntoRoute), +} + +impl Endpoint +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + fn into_route(self, state: S) -> Route { + match self { + Endpoint::MethodRouter(method_router) => Route::new(method_router.with_state(state)), + Endpoint::Route(route) => route, + Endpoint::NestedRouter(router) => router.into_route(state), + } + } + + fn layer(self, layer: L) -> Endpoint + where + L: Layer> + Clone + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + >>::Response: IntoResponse + 'static, + >>::Error: Into + 'static, + >>::Future: Send + 'static, + NewReqBody: 'static, + { + match self { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.layer(layer)) + } + Endpoint::Route(route) => Endpoint::Route(route.layer(layer)), + Endpoint::NestedRouter(router) => { + Endpoint::NestedRouter(router.map(|route| route.layer(layer))) + } + } + } } impl Clone for Endpoint { @@ -736,19 +642,31 @@ impl Clone for Endpoint { match self { Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()), Self::Route(inner) => Self::Route(inner.clone()), + Self::NestedRouter(router) => Self::NestedRouter(router.clone()), } } } -impl fmt::Debug for Endpoint { +impl fmt::Debug for Endpoint +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::MethodRouter(inner) => inner.fmt(f), - Self::Route(inner) => inner.fmt(f), + Self::MethodRouter(method_router) => { + f.debug_tuple("MethodRouter").field(method_router).finish() + } + Self::Route(route) => f.debug_tuple("Route").field(route).finish(), + Self::NestedRouter(router) => f.debug_tuple("NestedRouter").field(router).finish(), } } } +enum RouterOrService { + Router(Router), + Service(T), +} + #[test] #[allow(warnings)] fn traits() { diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 4b735a8d..25c8b859 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -17,9 +17,10 @@ use std::{ task::{Context, Poll}, }; use tower::{ - util::{BoxCloneService, Oneshot}, - ServiceExt, + util::{BoxCloneService, MapResponseLayer, Oneshot}, + ServiceBuilder, ServiceExt, }; +use tower_layer::Layer; use tower_service::Service; /// How routes are stored inside a [`Router`](super::Router). @@ -46,6 +47,25 @@ impl Route { ) -> Oneshot, Response, E>, Request> { self.0.clone().oneshot(req) } + + pub(crate) fn layer(self, layer: L) -> Route + where + L: Layer> + Clone + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + >>::Response: IntoResponse + 'static, + >>::Error: Into + 'static, + >>::Future: Send + 'static, + NewReqBody: 'static, + NewError: 'static, + { + let layer = ServiceBuilder::new() + .map_err(Into::into) + .layer(MapResponseLayer::new(IntoResponse::into_response)) + .layer(layer) + .into_inner(); + + Route::new(layer.layer(self)) + } } impl Clone for Route { diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs index f19f3856..f3ba1f9e 100644 --- a/axum/src/routing/service.rs +++ b/axum/src/routing/service.rs @@ -1,5 +1,5 @@ use super::{ - future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router, + future::RouteFuture, url_params, FallbackRoute, IntoMakeService, Node, Route, RouteId, Router, }; use crate::{ body::{Body, HttpBody}, @@ -28,26 +28,17 @@ impl RouterService where B: HttpBody + Send + 'static, { - #[track_caller] - pub(super) fn new(router: Router) -> Self + pub(super) fn new(router: Router, state: S) -> Self where S: Clone + Send + Sync + 'static, { - let state = router - .state - .expect("Can't turn a `Router` that wants to inherit state into a service"); + let fallback = router.fallback.into_fallback_route(&state); let routes = router .routes .into_iter() .map(|(route_id, endpoint)| { - let route = match endpoint { - Endpoint::MethodRouter(method_router) => { - Route::new(method_router.with_state(state.clone())) - } - Endpoint::Route(route) => route, - }; - + let route = endpoint.into_route(state.clone()); (route_id, route) }) .collect(); @@ -55,7 +46,7 @@ where Self { routes, node: router.node, - fallback: router.fallback.into_fallback_route(&state), + fallback, } } @@ -84,6 +75,28 @@ where route.call(req) } + + /// Convert the router into a [`MakeService`] and no state. + /// + /// See [`Router::into_make_service`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService> { + IntoMakeService::new(self) + } + + /// Convert the router into a [`MakeService`] which stores information + /// about the incoming connection and has no state. + /// + /// See [`Router::into_make_service_with_connect_info`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + #[cfg(feature = "tokio")] + pub fn into_make_service_with_connect_info( + self, + ) -> crate::extract::connect_info::IntoMakeServiceWithConnectInfo, C> { + crate::extract::connect_info::IntoMakeServiceWithConnectInfo::new(self) + } } impl Clone for RouterService { diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index b9993ab8..923e10d1 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -1,3 +1,5 @@ +use tower::ServiceExt; + use super::*; use crate::middleware::{map_request, map_response}; @@ -50,10 +52,11 @@ async fn or() { #[tokio::test] async fn fallback_accessing_state() { - let app = Router::with_state("state") - .fallback(|State(state): State<&'static str>| async move { state }); + let app = Router::new() + .fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); - let client = TestClient::new(app); + let client = TestClient::from_service(app); let res = client.get("/does-not-exist").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -158,3 +161,45 @@ async fn also_inherits_default_layered_fallback() { assert_eq!(res.headers()["x-from-fallback"], "1"); assert_eq!(res.text().await, "outer"); } + +#[tokio::test] +async fn fallback_inherited_into_nested_router_service() { + let inner = Router::new() + .route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ) + .with_state("inner"); + + // with a different state + let app = Router::<()>::new() + .nest_service("/foo", inner) + .fallback(outer_fallback); + + let client = TestClient::new(app); + let res = client.get("/foo/not-found").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn fallback_inherited_into_nested_opaque_service() { + let inner = Router::new() + .route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ) + .with_state("inner") + // even if the service is made more opaque it should still inherit the fallback + .boxed_clone(); + + // with a different state + let app = Router::<()>::new() + .nest_service("/foo", inner) + .fallback(outer_fallback); + + let client = TestClient::new(app); + let res = client.get("/foo/not-found").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); +} diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 804549d3..abad660a 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -397,87 +397,3 @@ async fn middleware_that_return_early() { ); assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } - -#[tokio::test] -async fn merge_with_different_state_type() { - let inner = Router::with_state("inner".to_owned()).route( - "/foo", - get(|State(state): State| async move { state }), - ); - - let app = Router::with_state("outer").merge(inner).route( - "/bar", - get(|State(state): State<&'static str>| async move { state }), - ); - - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.text().await, "inner"); - - let res = client.get("/bar").send().await; - assert_eq!(res.text().await, "outer"); -} - -#[tokio::test] -async fn merging_routes_different_method_different_states() { - let get = Router::with_state("get state").route( - "/", - get(|State(state): State<&'static str>| async move { state }), - ); - - let post = Router::with_state("post state").route( - "/", - post(|State(state): State<&'static str>| async move { state }), - ); - - let app = Router::new().merge(get).merge(post); - - let client = TestClient::new(app); - - let res = client.get("/").send().await; - assert_eq!(res.text().await, "get state"); - - let res = client.post("/").send().await; - assert_eq!(res.text().await, "post state"); -} - -#[tokio::test] -async fn merging_routes_different_paths_different_states() { - let foo = Router::with_state("foo state").route( - "/foo", - get(|State(state): State<&'static str>| async move { state }), - ); - - let bar = Router::with_state("bar state").route( - "/bar", - get(|State(state): State<&'static str>| async move { state }), - ); - - let app = Router::new().merge(foo).merge(bar); - - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "foo state"); - - let res = client.get("/bar").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "bar state"); -} - -#[tokio::test] -async fn inherit_state_via_merge() { - let foo = Router::inherit_state().route( - "/foo", - get(|State(state): State<&'static str>| async move { state }), - ); - - let app = Router::with_state("state").merge(foo); - let client = TestClient::new(app); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "state"); -} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 0f31256b..43db1512 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -760,8 +760,8 @@ async fn extract_state() { inner: InnerState { value: 2 }, }; - let app = Router::with_state(state).route("/", get(handler)); - let client = TestClient::new(app); + let app = Router::new().route("/", get(handler)).with_state(state); + let client = TestClient::from_service(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -769,12 +769,14 @@ async fn extract_state() { #[tokio::test] async fn explicitly_set_state() { - let app = Router::with_state("...").route_service( - "/", - get(|State(state): State<&'static str>| async move { state }).with_state("foo"), - ); + let app = Router::new() + .route_service( + "/", + get(|State(state): State<&'static str>| async move { state }).with_state("foo"), + ) + .with_state("..."); - let client = TestClient::new(app); + let client = TestClient::from_service(app); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 29e82860..4daa8f67 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -265,7 +265,7 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::<_, Body>::new().nest("/one/*rest", Router::new()); + Router::<(), Body>::new().nest("/one/*rest", Router::new()); } #[tokio::test] @@ -424,48 +424,3 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); - -#[tokio::test] -async fn nesting_with_different_state() { - let inner = Router::with_state("inner".to_owned()).route( - "/foo", - get(|State(state): State| async move { state }), - ); - - let outer = Router::with_state("outer") - .route( - "/foo", - get(|State(state): State<&'static str>| async move { state }), - ) - .nest("/nested", inner) - .route( - "/bar", - get(|State(state): State<&'static str>| async move { state }), - ); - - let client = TestClient::new(outer); - - let res = client.get("/foo").send().await; - assert_eq!(res.text().await, "outer"); - - let res = client.get("/nested/foo").send().await; - assert_eq!(res.text().await, "inner"); - - let res = client.get("/bar").send().await; - assert_eq!(res.text().await, "outer"); -} - -#[tokio::test] -async fn inherit_state_via_nest() { - let foo = Router::inherit_state().route( - "/foo", - get(|State(state): State<&'static str>| async move { state }), - ); - - let app = Router::with_state("state").nest("/test", foo); - let client = TestClient::new(app); - - let res = client.get("/test/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "state"); -} diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 26abe0c0..296b8131 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -15,10 +15,7 @@ pub(crate) struct TestClient { } impl TestClient { - pub(crate) fn new(router: Router) -> Self - where - S: Clone + Send + Sync + 'static, - { + pub(crate) fn new(router: Router<(), Body>) -> Self { Self::from_service(router.into_service()) } diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs index a8d84cb9..6dd4cc2c 100644 --- a/examples/async-graphql/src/main.rs +++ b/examples/async-graphql/src/main.rs @@ -34,7 +34,9 @@ async fn main() { .data(StarWars::new()) .finish(); - let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler)); + let app = Router::new() + .route("/", get(graphql_playground).post(graphql_handler)) + .with_state(schema); println!("Playground: http://localhost:3000"); diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 092a7d96..32b38613 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -44,9 +44,10 @@ async fn main() { let app_state = Arc::new(AppState { user_set, tx }); - let app = Router::with_state(app_state) + let app = Router::new() .route("/", get(index)) - .route("/websocket", get(websocket_handler)); + .route("/websocket", get(websocket_handler)) + .with_state(app_state); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index 32f57049..af897e3a 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -36,9 +36,10 @@ async fn main() { let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; // Build our application with some routes - let app = Router::with_state(user_repo) + let app = Router::new() .route("/users/:id", get(users_show)) - .route("/users", post(users_create)); + .route("/users", post(users_create)) + .with_state(user_repo); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index e176436b..334e0e25 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -43,7 +43,7 @@ async fn main() { let shared_state = SharedState::default(); // Build our application by composing routes - let app = Router::with_state(Arc::clone(&shared_state)) + let app = Router::new() .route( "/:key", // Add compression to `kv_get` @@ -60,7 +60,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes(shared_state)) + .nest("/admin", admin_routes()) // Add middleware to all routes .layer( ServiceBuilder::new() @@ -69,9 +69,9 @@ async fn main() { .load_shed() .concurrency_limit(1024) .timeout(Duration::from_secs(10)) - .layer(TraceLayer::new_for_http()) - .into_inner(), - ); + .layer(TraceLayer::new_for_http()), + ) + .with_state(Arc::clone(&shared_state)); // Run our app with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -115,7 +115,7 @@ async fn list_keys(State(state): State) -> String { .join("\n") } -fn admin_routes(state: SharedState) -> Router { +fn admin_routes() -> Router { async fn delete_all_keys(State(state): State) { state.write().unwrap().db.clear(); } @@ -124,7 +124,7 @@ fn admin_routes(state: SharedState) -> Router { state.write().unwrap().db.remove(&key); } - Router::with_state(state) + Router::new() .route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require bearer auth for all admin routes diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 67327184..ea692df8 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -47,12 +47,13 @@ async fn main() { oauth_client, }; - let app = Router::with_state(app_state) + let app = Router::new() .route("/", get(index)) .route("/auth/discord", get(discord_auth)) .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) - .route("/logout", get(logout)); + .route("/logout", get(logout)) + .with_state(app_state); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index af74ea12..d8cabb79 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -24,7 +24,7 @@ async fn main() { let client = Client::new(); - let app = Router::with_state(client).route("/", get(handler)); + let app = Router::new().route("/", get(handler)).with_state(client); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 60ebfd21..716ee41a 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -39,7 +39,7 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); - let app = Router::with_state(store).route("/", get(handler)); + let app = Router::new().route("/", get(handler)).with_state(store); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 9ba41ed8..0330edab 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -46,10 +46,12 @@ async fn main() { .expect("can connect to database"); // build our application with some routes - let app = Router::with_state(pool).route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ); + let app = Router::new() + .route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ) + .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index f323ebb3..98910c5a 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -46,7 +46,7 @@ async fn main() { let db = Db::default(); // Compose the routes - let app = Router::with_state(db) + let app = Router::new() .route("/todos", get(todos_index).post(todos_create)) .route("/todos/:id", patch(todos_update).delete(todos_delete)) // Add middleware to all routes @@ -65,7 +65,8 @@ async fn main() { .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) .into_inner(), - ); + ) + .with_state(db); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index 5e60c2bb..faef52b5 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -33,10 +33,12 @@ async fn main() { let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes - let app = Router::with_state(pool).route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ); + let app = Router::new() + .route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ) + .with_state(pool); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000));