From 69d64cecc3707d4c7897e7f8abb2111f305c5419 Mon Sep 17 00:00:00 2001 From: Jonas Platte <jplatte+git@posteo.de> Date: Thu, 22 Sep 2022 12:10:55 +0200 Subject: [PATCH] Split RouterService off of Router (#1381) --- axum-extra/src/extract/cookie/mod.rs | 7 +- axum-extra/src/lib.rs | 2 +- axum-extra/src/routing/resource.rs | 6 +- axum/src/docs/middleware.md | 2 +- axum/src/docs/routing/nest.md | 14 +- axum/src/docs/routing/route_service.md | 2 +- axum/src/extract/matched_path.rs | 15 +- axum/src/extract/request_parts.rs | 4 +- axum/src/handler/mod.rs | 2 +- axum/src/lib.rs | 2 +- axum/src/middleware/from_fn.rs | 1 + axum/src/routing/mod.rs | 143 +++------------ axum/src/routing/service.rs | 164 ++++++++++++++++++ axum/src/routing/tests/fallback.rs | 5 +- axum/src/routing/tests/get_to_head.rs | 2 + axum/src/routing/tests/merge.rs | 63 ++++--- axum/src/routing/tests/mod.rs | 9 +- axum/src/routing/tests/nest.rs | 113 +++++++----- axum/src/test_helpers/mod.rs | 2 +- axum/src/test_helpers/test_client.rs | 11 +- examples/handle-head-request/src/main.rs | 6 +- examples/http-proxy/src/main.rs | 8 +- examples/key-value-store/src/main.rs | 2 +- .../src/main.rs | 1 + examples/rest-grpc-multiplex/src/main.rs | 2 +- examples/testing/src/main.rs | 8 +- 26 files changed, 371 insertions(+), 225 deletions(-) create mode 100644 axum/src/routing/service.rs diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 20d015d1..afe2451c 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -258,7 +258,8 @@ mod tests { let app = Router::<_, Body>::with_state(state) .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) - .route("/remove", get(remove_cookie)); + .route("/remove", get(remove_cookie)) + .into_service(); let res = app .clone() @@ -344,7 +345,9 @@ mod tests { custom_key: CustomKey(Key::generate()), }; - let app = Router::<_, Body>::with_state(state).route("/get", get(get_cookie)); + let app = Router::<_, Body>::with_state(state) + .route("/get", get(get_cookie)) + .into_service(); let res = app .clone() diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index d20bae20..3138cac5 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -95,7 +95,7 @@ pub mod __private { pub(crate) mod test_helpers { #![allow(unused_imports)] - use axum::{body::HttpBody, BoxError}; + use axum::{body::HttpBody, BoxError, Router}; mod test_client { #![allow(dead_code)] diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index c15f94d8..5d5150e0 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -159,7 +159,7 @@ impl<B> From<Resource<B>> for Router<B> { mod tests { #[allow(unused_imports)] use super::*; - use axum::{extract::Path, http::Method, Router}; + use axum::{extract::Path, http::Method, routing::RouterService, Router}; use http::Request; use tower::{Service, ServiceExt}; @@ -174,7 +174,7 @@ mod tests { .update(|Path(id): Path<u64>| async move { format!("users#update id={}", id) }) .destroy(|Path(id): Path<u64>| async move { format!("users#destroy id={}", id) }); - let mut app = Router::new().merge(users); + let mut app = Router::new().merge(users).into_service(); assert_eq!( call_route(&mut app, Method::GET, "/users").await, @@ -217,7 +217,7 @@ mod tests { ); } - async fn call_route(app: &mut Router<()>, method: Method, uri: &str) -> String { + async fn call_route(app: &mut RouterService, method: Method, uri: &str) -> String { let res = app .ready() .await diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index d070d1ec..c3579c13 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -555,7 +555,7 @@ async fn rewrite_request_uri<B>(req: Request<B>, next: Next<B>) -> Response { // this can be any `tower::Layer` let middleware = axum::middleware::from_fn(rewrite_request_uri); -let app = Router::new(); +let app = Router::new().into_service(); // apply the layer around the whole `Router` // this way the middleware will run before `Router` receives the request diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 79558927..f60a83d8 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -16,10 +16,10 @@ let user_routes = Router::new().route("/:id", get(|| async {})); let team_routes = Router::new().route("/", post(|| async {})); let api_routes = Router::new() - .nest("/users", user_routes) - .nest("/teams", team_routes); + .nest("/users", user_routes.into_service()) + .nest("/teams", team_routes.into_service()); -let app = Router::new().nest("/api", api_routes); +let app = Router::new().nest("/api", api_routes.into_service()); // Our app now accepts // - GET /api/users/:id @@ -58,7 +58,7 @@ async fn users_get(Path(params): Path<HashMap<String, String>>) { let users_api = Router::new().route("/users/:id", get(users_get)); -let app = Router::new().nest("/:version/api", users_api); +let app = Router::new().nest("/:version/api", users_api.into_service()); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; @@ -82,7 +82,7 @@ let app = Router::new() .route("/foo/*rest", get(|uri: Uri| async { // `uri` will contain `/foo` })) - .nest("/bar", nested_router); + .nest("/bar", nested_router.into_service()); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); # }; @@ -103,7 +103,7 @@ async fn fallback() -> (StatusCode, &'static str) { let api_routes = Router::new().nest("/users", get(|| async {})); let app = Router::new() - .nest("/api", api_routes) + .nest("/api", api_routes.into_service()) .fallback(fallback); # let _: Router = app; ``` @@ -135,7 +135,7 @@ let api_routes = Router::new() .fallback(api_fallback); let app = Router::new() - .nest("/api", api_routes) + .nest("/api", api_routes.into_service()) .fallback(fallback); # let _: Router = app; ``` diff --git a/axum/src/docs/routing/route_service.md b/axum/src/docs/routing/route_service.md index a14323a9..623c6cb6 100644 --- a/axum/src/docs/routing/route_service.md +++ b/axum/src/docs/routing/route_service.md @@ -69,7 +69,7 @@ use axum::{routing::get, Router}; let app = Router::new().route_service( "/", - Router::new().route("/foo", get(|| async {})), + Router::new().route("/foo", get(|| async {})).into_service(), ); # async { # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 6af783d9..4472012a 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -148,7 +148,7 @@ mod tests { "/:key", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ) - .nest("/api", api) + .nest("/api", api.into_service()) .nest( "/public", Router::new() @@ -156,7 +156,8 @@ mod tests { // have to set the middleware here since otherwise the // matched path is just `/public/*` since we're nesting // this router - .layer(layer_fn(SetMatchedPathExtension)), + .layer(layer_fn(SetMatchedPathExtension)) + .into_service(), ) .nest("/foo", handler.into_service()) .layer(layer_fn(SetMatchedPathExtension)); @@ -197,10 +198,12 @@ mod tests { async fn nested_opaque_routers_append_to_matched_path() { let app = Router::new().nest( "/:a", - Router::new().route( - "/:b", - get(|path: MatchedPath| async move { path.as_str().to_owned() }), - ), + Router::new() + .route( + "/:b", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ) + .into_service(), ); let client = TestClient::new(app); diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 933e7a71..bf35fe75 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -38,7 +38,7 @@ use sync_wrapper::SyncWrapper; /// }), /// ); /// -/// let app = Router::new().nest("/api", api_routes); +/// let app = Router::new().nest("/api", api_routes.into_service()); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; @@ -75,7 +75,7 @@ use sync_wrapper::SyncWrapper; /// }), /// ); /// -/// let app = Router::new().nest("/api", api_routes); +/// let app = Router::new().nest("/api", api_routes.into_service()); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 785d9277..3fa51ca0 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -375,7 +375,7 @@ mod tests { format!("you said: {}", body) } - let client = TestClient::new(handle.into_service()); + let client = TestClient::from_service(handle.into_service()); let res = client.post("/").body("hi there!").send().await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/lib.rs b/axum/src/lib.rs index d04d5805..171df9e2 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -470,7 +470,7 @@ pub use self::extension::Extension; #[cfg(feature = "json")] pub use self::json::Json; #[doc(inline)] -pub use self::routing::Router; +pub use self::routing::{Router, RouterService}; #[doc(inline)] #[cfg(feature = "headers")] diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 7fb93474..23d25acf 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -404,6 +404,7 @@ mod tests { .layer(from_fn(insert_header)); let res = app + .into_service() .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 1a107560..12cc507b 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,24 +1,17 @@ //! Routing between [`Service`]s and handlers. -use self::{future::RouteFuture, not_found::NotFound}; +use self::not_found::NotFound; use crate::{ body::{Body, HttpBody}, extract::connect_info::IntoMakeServiceWithConnectInfo, handler::Handler, - response::Response, util::try_downcast, Extension, }; use axum_core::response::IntoResponse; use http::Request; use matchit::MatchError; -use std::{ - collections::HashMap, - convert::Infallible, - fmt, - sync::Arc, - task::{Context, Poll}, -}; +use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc}; use tower::{util::MapResponseLayer, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; @@ -33,10 +26,14 @@ mod route; mod strip_prefix; pub(crate) mod url_params; +mod service; #[cfg(test)] mod tests; -pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; +pub use self::{ + into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route, + service::RouterService, +}; pub use self::method_routing::{ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, @@ -226,9 +223,12 @@ where panic!("Paths must start with a `/`"); } - let service = match try_downcast::<Router<S, B>, _>(service) { + let service = match try_downcast::<RouterService<B>, _>(service) { Ok(_) => { - panic!("Invalid route: `Router::route_service` cannot be used with `Router`s. Use `Router::nest` instead") + panic!( + "Invalid route: `Router::route_service` cannot be used with `RouterService`s. \ + Use `Router::nest` instead" + ); } Err(svc) => svc, }; @@ -438,6 +438,11 @@ where self } + /// Convert this router into a [`RouterService`]. + pub fn into_service(self) -> RouterService<B> { + RouterService::new(self) + } + /// Convert this router into a [`MakeService`], that is a [`Service`] whose /// response is another service. /// @@ -461,73 +466,15 @@ where /// ``` /// /// [`MakeService`]: tower::make::MakeService - pub fn into_make_service(self) -> IntoMakeService<Self> { - IntoMakeService::new(self) + pub fn into_make_service(self) -> IntoMakeService<RouterService<B>> { + IntoMakeService::new(self.into_service()) } #[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")] - pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> { - IntoMakeServiceWithConnectInfo::new(self) - } - - #[inline] - fn call_route( - &self, - match_: matchit::Match<&RouteId>, - mut req: Request<B>, - ) -> RouteFuture<B, Infallible> { - let id = *match_.value; - - #[cfg(feature = "matched-path")] - { - fn set_matched_path( - id: RouteId, - route_id_to_path: &HashMap<RouteId, Arc<str>>, - extensions: &mut http::Extensions, - ) { - if let Some(matched_path) = route_id_to_path.get(&id) { - use crate::extract::MatchedPath; - - let matched_path = if let Some(previous) = extensions.get::<MatchedPath>() { - // a previous `MatchedPath` might exist if we're inside a nested Router - let previous = if let Some(previous) = - previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) - { - previous - } else { - previous.as_str() - }; - - let matched_path = format!("{}{}", previous, matched_path); - matched_path.into() - } else { - Arc::clone(matched_path) - }; - extensions.insert(MatchedPath(matched_path)); - } else { - #[cfg(debug_assertions)] - panic!("should always have a matched path for a route id"); - } - } - - set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut()); - } - - url_params::insert_url_params(req.extensions_mut(), match_.params); - - let mut route = self - .routes - .get(&id) - .expect("no route for id. This is a bug in axum. Please file an issue") - .clone(); - - match &mut route { - Endpoint::MethodRouter(inner) => inner - .clone() - .with_state_arc(Arc::clone(&self.state)) - .call(req), - Endpoint::Route(inner) => inner.call(req), - } + pub fn into_make_service_with_connect_info<C>( + self, + ) -> IntoMakeServiceWithConnectInfo<RouterService<B>, C> { + IntoMakeServiceWithConnectInfo::new(self.into_service()) } /// Get a reference to the state. @@ -536,48 +483,6 @@ where } } -impl<S, B> Service<Request<B>> for Router<S, B> -where - B: HttpBody + Send + 'static, - S: Send + Sync + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = RouteFuture<B, Infallible>; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&mut self, mut req: Request<B>) -> Self::Future { - #[cfg(feature = "original-uri")] - { - use crate::extract::OriginalUri; - - if req.extensions().get::<OriginalUri>().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } - } - - let path = req.uri().path().to_owned(); - - match self.node.at(&path) { - Ok(match_) => self.call_route(match_, req), - Err( - MatchError::NotFound - | MatchError::ExtraTrailingSlash - | MatchError::MissingTrailingSlash, - ) => match &self.fallback { - Fallback::Default(inner) => inner.clone().call(req), - Fallback::Custom(inner) => inner.clone().call(req), - }, - } - } -} - /// Wrapper around `matchit::Router` that supports merging two `Router`s. #[derive(Clone, Default)] struct Node { @@ -665,7 +570,7 @@ impl<B, E> Fallback<B, E> { } enum Endpoint<S, B> { - MethodRouter(MethodRouter<S, B, Infallible>), + MethodRouter(MethodRouter<S, B>), Route(Route<B>), } diff --git a/axum/src/routing/service.rs b/axum/src/routing/service.rs new file mode 100644 index 00000000..6b0b4ac2 --- /dev/null +++ b/axum/src/routing/service.rs @@ -0,0 +1,164 @@ +use std::{ + collections::HashMap, + convert::Infallible, + sync::Arc, + task::{Context, Poll}, +}; + +use http::Request; +use matchit::MatchError; +use tower::Service; + +use super::{ + future::RouteFuture, url_params, Endpoint, Fallback, Node, Route, RouteId, Router, + NEST_TAIL_PARAM_CAPTURE, +}; +use crate::{ + body::{Body, HttpBody}, + response::Response, +}; + +/// TOOD: Docs +#[derive(Debug)] +pub struct RouterService<B = Body> { + routes: HashMap<RouteId, Route<B>>, + node: Arc<Node>, + fallback: Route<B>, +} + +impl<B> RouterService<B> +where + B: HttpBody + Send + 'static, +{ + pub(super) fn new<S>(router: Router<S, B>) -> Self + where + S: Send + Sync + 'static, + { + let routes = router + .routes + .into_iter() + .map(|(route_id, endpoint)| { + let route = match endpoint { + Endpoint::MethodRouter(method_router) => { + Route::new(method_router.with_state_arc(Arc::clone(&router.state))) + } + Endpoint::Route(route) => route, + }; + + (route_id, route) + }) + .collect(); + + Self { + routes, + node: router.node, + fallback: match router.fallback { + Fallback::Default(route) => route, + Fallback::Custom(route) => route, + }, + } + } + + #[inline] + fn call_route( + &self, + match_: matchit::Match<&RouteId>, + mut req: Request<B>, + ) -> RouteFuture<B, Infallible> { + let id = *match_.value; + + #[cfg(feature = "matched-path")] + { + fn set_matched_path( + id: RouteId, + route_id_to_path: &HashMap<RouteId, Arc<str>>, + extensions: &mut http::Extensions, + ) { + if let Some(matched_path) = route_id_to_path.get(&id) { + use crate::extract::MatchedPath; + + let matched_path = if let Some(previous) = extensions.get::<MatchedPath>() { + // a previous `MatchedPath` might exist if we're inside a nested Router + let previous = if let Some(previous) = + previous.as_str().strip_suffix(NEST_TAIL_PARAM_CAPTURE) + { + previous + } else { + previous.as_str() + }; + + let matched_path = format!("{}{}", previous, matched_path); + matched_path.into() + } else { + Arc::clone(matched_path) + }; + extensions.insert(MatchedPath(matched_path)); + } else { + #[cfg(debug_assertions)] + panic!("should always have a matched path for a route id"); + } + } + + set_matched_path(id, &self.node.route_id_to_path, req.extensions_mut()); + } + + url_params::insert_url_params(req.extensions_mut(), match_.params); + + let mut route = self + .routes + .get(&id) + .expect("no route for id. This is a bug in axum. Please file an issue") + .clone(); + + route.call(req) + } +} + +impl<B> Clone for RouterService<B> { + fn clone(&self) -> Self { + Self { + routes: self.routes.clone(), + node: Arc::clone(&self.node), + fallback: self.fallback.clone(), + } + } +} + +impl<B> Service<Request<B>> for RouterService<B> +where + B: HttpBody + Send + 'static, + //S: Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = RouteFuture<B, Infallible>; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&mut self, mut req: Request<B>) -> Self::Future { + #[cfg(feature = "original-uri")] + { + use crate::extract::OriginalUri; + + if req.extensions().get::<OriginalUri>().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } + } + + let path = req.uri().path().to_owned(); + + match self.node.at(&path) { + Ok(match_) => self.call_route(match_, req), + Err( + MatchError::NotFound + | MatchError::ExtraTrailingSlash + | MatchError::MissingTrailingSlash, + ) => self.fallback.clone().call(req), + } + } +} diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 4da166ba..d9b27da0 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -18,7 +18,10 @@ async fn basic() { #[tokio::test] async fn nest() { let app = Router::new() - .nest("/foo", Router::new().route("/bar", get(|| async {}))) + .nest( + "/foo", + Router::new().route("/bar", get(|| async {})).into_service(), + ) .fallback(|| async { "fallback" }); let client = TestClient::new(app); diff --git a/axum/src/routing/tests/get_to_head.rs b/axum/src/routing/tests/get_to_head.rs index 21888e6e..f0cd201c 100644 --- a/axum/src/routing/tests/get_to_head.rs +++ b/axum/src/routing/tests/get_to_head.rs @@ -19,6 +19,7 @@ mod for_handlers { // don't use reqwest because it always strips bodies from HEAD responses let res = app + .into_service() .oneshot( Request::builder() .uri("/") @@ -54,6 +55,7 @@ mod for_services { // don't use reqwest because it always strips bodies from HEAD responses let res = app + .into_service() .oneshot( Request::builder() .uri("/") diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 5f67dcd7..7ff34a77 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -1,8 +1,5 @@ use super::*; -use crate::{ - body::HttpBody, error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, - Json, -}; +use crate::{error_handling::HandleErrorLayer, extract::OriginalUri, response::IntoResponse, Json}; use serde_json::{json, Value}; use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; @@ -62,15 +59,7 @@ async fn multiple_ors_balanced_differently() { test("four", one.merge(two.merge(three.merge(four)))).await; - async fn test<S, ResBody>(name: &str, app: S) - where - S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static, - ResBody: HttpBody + Send + 'static, - ResBody::Data: Send, - ResBody::Error: Into<BoxError>, - S::Future: Send, - S::Error: Into<BoxError>, - { + async fn test(name: &str, app: Router) { let client = TestClient::new(app); for n in ["one", "two", "three", "four"].iter() { @@ -93,7 +82,7 @@ async fn nested_or() { assert_eq!(client.get("/bar").send().await.text().await, "bar"); assert_eq!(client.get("/baz").send().await.text().await, "baz"); - let client = TestClient::new(Router::new().nest("/foo", bar_or_baz)); + let client = TestClient::new(Router::new().nest("/foo", bar_or_baz.into_service())); assert_eq!(client.get("/foo/bar").send().await.text().await, "bar"); assert_eq!(client.get("/foo/baz").send().await.text().await, "baz"); } @@ -156,7 +145,10 @@ async fn layer_and_handle_error() { #[tokio::test] async fn nesting() { let one = Router::new().route("/foo", get(|| async {})); - let two = Router::new().nest("/bar", Router::new().route("/baz", get(|| async {}))); + let two = Router::new().nest( + "/bar", + Router::new().route("/baz", get(|| async {})).into_service(), + ); let app = one.merge(two); let client = TestClient::new(app); @@ -240,7 +232,12 @@ async fn all_the_uris( #[tokio::test] async fn nesting_and_seeing_the_right_uri() { - let one = Router::new().nest("/foo/", Router::new().route("/bar", get(all_the_uris))); + let one = Router::new().nest( + "/foo/", + Router::new() + .route("/bar", get(all_the_uris)) + .into_service(), + ); let two = Router::new().route("/foo", get(all_the_uris)); let client = TestClient::new(one.merge(two)); @@ -272,7 +269,14 @@ async fn nesting_and_seeing_the_right_uri() { async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { let one = Router::new().nest( "/foo/", - Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), + Router::new() + .nest( + "/bar", + Router::new() + .route("/baz", get(all_the_uris)) + .into_service(), + ) + .into_service(), ); let two = Router::new().route("/foo", get(all_the_uris)); @@ -305,9 +309,21 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { let one = Router::new().nest( "/one", - Router::new().nest("/bar", Router::new().route("/baz", get(all_the_uris))), + Router::new() + .nest( + "/bar", + Router::new() + .route("/baz", get(all_the_uris)) + .into_service(), + ) + .into_service(), + ); + let two = Router::new().nest( + "/two", + Router::new() + .route("/qux", get(all_the_uris)) + .into_service(), ); - let two = Router::new().nest("/two", Router::new().route("/qux", get(all_the_uris))); let three = Router::new().route("/three", get(all_the_uris)); let client = TestClient::new(one.merge(two).merge(three)); @@ -350,7 +366,14 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { let one = Router::new().nest( "/one", - Router::new().nest("/foo", Router::new().route("/bar", get(all_the_uris))), + Router::new() + .nest( + "/foo", + Router::new() + .route("/bar", get(all_the_uris)) + .into_service(), + ) + .into_service(), ); let two = Router::new().route("/two/foo", get(all_the_uris)); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 4d4c58cd..c61efbf4 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -444,11 +444,12 @@ async fn middleware_still_run_for_unmatched_requests() { } #[tokio::test] -#[should_panic( - expected = "Invalid route: `Router::route_service` cannot be used with `Router`s. Use `Router::nest` instead" -)] +#[should_panic(expected = "\ + Invalid route: `Router::route_service` cannot be used with `RouterService`s. \ + Use `Router::nest` instead\ +")] async fn routing_to_router_panics() { - TestClient::new(Router::new().route_service("/", Router::new())); + TestClient::new(Router::new().route_service("/", Router::new().into_service())); } #[tokio::test] diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index f856b6e8..a0243f1c 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -37,7 +37,7 @@ async fn nesting_apps() { let app = Router::new() .route("/", get(|| async { "hi" })) - .nest("/:version/api", api_routes); + .nest("/:version/api", api_routes.into_service()); let client = TestClient::new(app); @@ -61,7 +61,7 @@ async fn nesting_apps() { #[tokio::test] async fn wrong_method_nest() { let nested_app = Router::new().route("/", get(|| async {})); - let app = Router::new().nest("/", nested_app); + let app = Router::new().nest("/", nested_app.into_service()); let client = TestClient::new(app); @@ -78,7 +78,7 @@ async fn wrong_method_nest() { #[tokio::test] async fn nesting_router_at_root() { let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("/", nested); + let app = Router::new().nest("/", nested.into_service()); let client = TestClient::new(app); @@ -96,7 +96,7 @@ async fn nesting_router_at_root() { #[tokio::test] async fn nesting_router_at_empty_path() { let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("", nested); + let app = Router::new().nest("", nested.into_service()); let client = TestClient::new(app); @@ -134,15 +134,18 @@ async fn nesting_handler_at_root() { async fn nested_url_extractor() { let app = Router::new().nest( "/foo", - Router::new().nest( - "/bar", - Router::new() - .route("/baz", get(|uri: Uri| async move { uri.to_string() })) - .route( - "/qux", - get(|req: Request<Body>| async move { req.uri().to_string() }), - ), - ), + Router::new() + .nest( + "/bar", + Router::new() + .route("/baz", get(|uri: Uri| async move { uri.to_string() })) + .route( + "/qux", + get(|req: Request<Body>| async move { req.uri().to_string() }), + ) + .into_service(), + ) + .into_service(), ); let client = TestClient::new(app); @@ -160,13 +163,17 @@ async fn nested_url_extractor() { async fn nested_url_original_extractor() { let app = Router::new().nest( "/foo", - Router::new().nest( - "/bar", - Router::new().route( - "/baz", - get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), - ), - ), + Router::new() + .nest( + "/bar", + Router::new() + .route( + "/baz", + get(|uri: extract::OriginalUri| async move { uri.0.to_string() }), + ) + .into_service(), + ) + .into_service(), ); let client = TestClient::new(app); @@ -180,16 +187,20 @@ async fn nested_url_original_extractor() { async fn nested_service_sees_stripped_uri() { let app = Router::new().nest( "/foo", - Router::new().nest( - "/bar", - Router::new().route_service( - "/baz", - service_fn(|req: Request<Body>| async move { - let body = boxed(Body::from(req.uri().to_string())); - Ok::<_, Infallible>(Response::new(body)) - }), - ), - ), + Router::new() + .nest( + "/bar", + Router::new() + .route_service( + "/baz", + service_fn(|req: Request<Body>| async move { + let body = boxed(Body::from(req.uri().to_string())); + Ok::<_, Infallible>(Response::new(body)) + }), + ) + .into_service(), + ) + .into_service(), ); let client = TestClient::new(app); @@ -224,7 +235,8 @@ async fn nested_multiple_routes() { "/api", Router::new() .route("/users", get(|| async { "users" })) - .route("/teams", get(|| async { "teams" })), + .route("/teams", get(|| async { "teams" })) + .into_service(), ) .route("/", get(|| async { "root" })); @@ -239,7 +251,12 @@ async fn nested_multiple_routes() { #[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] fn nested_at_root_with_other_routes() { let _: Router = Router::new() - .nest("/", Router::new().route("/users", get(|| async {}))) + .nest( + "/", + Router::new() + .route("/users", get(|| async {})) + .into_service(), + ) .route("/", get(|| async {})); } @@ -248,11 +265,15 @@ async fn multiple_top_level_nests() { let app = Router::new() .nest( "/one", - Router::new().route("/route", get(|| async { "one" })), + Router::new() + .route("/route", get(|| async { "one" })) + .into_service(), ) .nest( "/two", - Router::new().route("/route", get(|| async { "two" })), + Router::new() + .route("/route", get(|| async { "two" })) + .into_service(), ); let client = TestClient::new(app); @@ -264,7 +285,7 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::<_, Body>::new().nest("/one/*rest", Router::new()); + Router::<_, Body>::new().nest("/one/*rest", Router::new().into_service()); } #[tokio::test] @@ -302,7 +323,10 @@ async fn outer_middleware_still_see_whole_url() { .route("/", get(handler)) .route("/foo", get(handler)) .route("/foo/bar", get(handler)) - .nest("/one", Router::new().route("/two", get(handler))) + .nest( + "/one", + Router::new().route("/two", get(handler)).into_service(), + ) .fallback(handler) .layer(tower::layer::layer_fn(SetUriExtension)); @@ -325,6 +349,7 @@ async fn nest_at_capture() { "/:b", get(|Path((a, b)): Path<(String, String)>| async move { format!("a={} b={}", a, b) }), ) + .into_service() .boxed_clone(); let app = Router::new().nest("/:a", api_routes); @@ -355,7 +380,10 @@ async fn nest_with_and_without_trailing() { #[tokio::test] async fn doesnt_call_outer_fallback() { let app = Router::new() - .nest("/foo", Router::new().route("/", get(|| async {}))) + .nest( + "/foo", + Router::new().route("/", get(|| async {})).into_service(), + ) .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); let client = TestClient::new(app); @@ -373,7 +401,9 @@ async fn doesnt_call_outer_fallback() { async fn nesting_with_root_inner_router() { let app = Router::new().nest( "/foo", - Router::new().route("/", get(|| async { "inner route" })), + Router::new() + .route("/", get(|| async { "inner route" })) + .into_service(), ); let client = TestClient::new(app); @@ -396,7 +426,8 @@ async fn fallback_on_inner() { "/foo", Router::new() .route("/", get(|| async {})) - .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), + .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }) + .into_service(), ) .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); @@ -420,7 +451,7 @@ macro_rules! nested_route_test { #[tokio::test] async fn $name() { let inner = Router::new().route($route_path, get(|| async {})); - let app = Router::new().nest($nested_path, inner); + let app = Router::new().nest($nested_path, inner.into_service()); let client = TestClient::new(app); let res = client.get($expected_path).send().await; let status = res.status(); @@ -455,7 +486,7 @@ async fn nesting_with_different_state() { "/foo", get(|State(state): State<&'static str>| async move { state }), ) - .nest("/nested", inner) + .nest("/nested", inner.into_service()) .route( "/bar", get(|State(state): State<&'static str>| async move { state }), diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index 60ff6575..f588d920 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -1,6 +1,6 @@ #![allow(clippy::blacklisted_name)] -use crate::{body::HttpBody, BoxError}; +use crate::{body::HttpBody, BoxError, Router}; mod test_client; pub(crate) use self::test_client::*; diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 45a72b6c..74d82623 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -1,4 +1,4 @@ -use super::{BoxError, HttpBody}; +use super::{BoxError, HttpBody, Router}; use bytes::Bytes; use http::{ header::{HeaderName, HeaderValue}, @@ -15,7 +15,14 @@ pub(crate) struct TestClient { } impl TestClient { - pub(crate) fn new<S, ResBody>(svc: S) -> Self + pub(crate) fn new<S>(router: Router<S, Body>) -> Self + where + S: Send + Sync + 'static, + { + Self::from_service(router.into_service()) + } + + pub(crate) fn from_service<S, ResBody>(svc: S) -> Self where S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static, ResBody: HttpBody + Send + 'static, diff --git a/examples/handle-head-request/src/main.rs b/examples/handle-head-request/src/main.rs index 492d3425..aece6c25 100644 --- a/examples/handle-head-request/src/main.rs +++ b/examples/handle-head-request/src/main.rs @@ -5,7 +5,7 @@ //! ``` use axum::response::{IntoResponse, Response}; -use axum::{http, routing::get, Router}; +use axum::{http, routing::get, Router, RouterService}; use std::net::SocketAddr; fn app() -> Router { @@ -50,7 +50,7 @@ mod tests { #[tokio::test] async fn test_get() { - let app = app(); + let app = app().into_service(); let response = app .oneshot(Request::get("/get-head").body(Body::empty()).unwrap()) @@ -66,7 +66,7 @@ mod tests { #[tokio::test] async fn test_implicit_head() { - let app = app(); + let app = app().into_service(); let response = app .oneshot(Request::head("/get-head").body(Body::empty()).unwrap()) diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs index a4b412bb..e57a73f4 100644 --- a/examples/http-proxy/src/main.rs +++ b/examples/http-proxy/src/main.rs @@ -35,15 +35,17 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); - let router = Router::new().route("/", get(|| async { "Hello, World!" })); + let router_svc = Router::new() + .route("/", get(|| async { "Hello, World!" })) + .into_service(); let service = tower::service_fn(move |req: Request<Body>| { - let router = router.clone(); + let router_svc = router_svc.clone(); async move { if req.method() == Method::CONNECT { proxy(req).await } else { - router.oneshot(req).await.map_err(|err| match err {}) + router_svc.oneshot(req).await.map_err(|err| match err {}) } } }); diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 4fb9e45b..4f46af74 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -52,7 +52,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes(shared_state)) + .nest("/admin", admin_routes(shared_state).into_service()) // Add middleware to all routes .layer( ServiceBuilder::new() diff --git a/examples/query-params-with-empty-strings/src/main.rs b/examples/query-params-with-empty-strings/src/main.rs index 0af20111..d5f2ba2d 100644 --- a/examples/query-params-with-empty-strings/src/main.rs +++ b/examples/query-params-with-empty-strings/src/main.rs @@ -104,6 +104,7 @@ mod tests { async fn send_request_get_body(query: &str) -> String { let body = app() + .into_service() .oneshot( Request::builder() .uri(format!("/?{}", query)) diff --git a/examples/rest-grpc-multiplex/src/main.rs b/examples/rest-grpc-multiplex/src/main.rs index 8376fcb0..7a0ea3d6 100644 --- a/examples/rest-grpc-multiplex/src/main.rs +++ b/examples/rest-grpc-multiplex/src/main.rs @@ -55,7 +55,7 @@ async fn main() { .init(); // build the rest service - let rest = Router::new().route("/", get(web_root)); + let rest = Router::new().route("/", get(web_root)).into_service(); // build the grpc service let grpc = GreeterServer::new(GrpcServiceImpl::default()); diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 0bb9b352..2671df6d 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -61,7 +61,7 @@ mod tests { #[tokio::test] async fn hello_world() { - let app = app(); + let app = app().into_service(); // `Router` implements `tower::Service<Request<Body>>` so we can // call it like any tower service, no need to run an HTTP server. @@ -78,7 +78,7 @@ mod tests { #[tokio::test] async fn json() { - let app = app(); + let app = app().into_service(); let response = app .oneshot( @@ -103,7 +103,7 @@ mod tests { #[tokio::test] async fn not_found() { - let app = app(); + let app = app().into_service(); let response = app .oneshot( @@ -154,7 +154,7 @@ mod tests { // in multiple request #[tokio::test] async fn multiple_request() { - let mut app = app(); + let mut app = app().into_service(); let request = Request::builder().uri("/").body(Body::empty()).unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap();