Change Router::with_state and impl Service for Router<()> (#1552)

* Implement `Service` for `Router<(), B>`

* wip

* wip

* fix some tests

* fix examples

* fix doc tests

* clean up docs

* changelog

* fix

* also call `with_state` when converting `MethodRouter` into a `MakeService`

* suggestions from review
This commit is contained in:
David Pedersen 2022-11-24 15:43:10 +01:00 committed by GitHub
parent fde38f6618
commit 0b26411f39
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
45 changed files with 576 additions and 738 deletions

View file

@ -68,7 +68,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// .route("/set", post(set_secret))
/// .route("/get", get(get_secret))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
///
/// If you have been using `Arc<AppState>` you cannot implement `FromRef<Arc<AppState>> for Key`.

View file

@ -86,7 +86,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// .route("/sessions", post(create_session))
/// .route("/me", get(me))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
/// If you have been using `Arc<AppState>` you cannot implement `FromRef<Arc<AppState>> for Key`.
/// You can use a new type instead:

View file

@ -147,7 +147,7 @@ impl<B> From<Resource<B>> for Router<B> {
mod tests {
#[allow(unused_imports)]
use super::*;
use axum::{extract::Path, http::Method, routing::RouterService, Router};
use axum::{extract::Path, http::Method, Router};
use http::Request;
use tower::{Service, ServiceExt};
@ -162,7 +162,7 @@ mod tests {
.update(|Path(id): Path<u64>| async move { format!("users#update id={}", id) })
.destroy(|Path(id): Path<u64>| async move { format!("users#destroy id={}", id) });
let mut app = Router::new().merge(users).into_service();
let mut app = Router::new().merge(users);
assert_eq!(
call_route(&mut app, Method::GET, "/users").await,
@ -205,7 +205,7 @@ mod tests {
);
}
async fn call_route(app: &mut RouterService, method: Method, uri: &str) -> String {
async fn call_route(app: &mut Router, method: Method, uri: &str) -> String {
let res = app
.ready()
.await

View file

@ -270,7 +270,7 @@ mod tests {
#[allow(dead_code)]
fn works_with_router_with_state() {
let _: axum::RouterService = Router::new()
let _: Router = Router::new()
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}))
.with_state(String::new());

View file

@ -606,7 +606,7 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// let app = Router::new()
/// .route("/", get(handler).post(other_handler))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
///
/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html

View file

@ -14,7 +14,7 @@ fn main() {
auth_token: Default::default(),
};
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/", get(handler))
.with_state(state);
}

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/a", get(|_: AppState| async {}))
.route("/b", get(|_: InnerState| async {}))
.with_state(AppState::default());

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequestParts;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/a", get(|_: AppState| async {}))
.route("/b", get(|_: InnerState| async {}))
.route("/c", get(|_: AppState, _: InnerState| async {}))

View file

@ -6,7 +6,7 @@ use axum::{
};
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/b", get(|_: Extractor| async {}))
.with_state(AppState::default());
}

View file

@ -7,7 +7,7 @@ use axum::{
use std::collections::HashMap;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/b", get(|_: Extractor| async {}))
.with_state(AppState::default());
}

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/", get(|_: Extractor| async {}))
.with_state(AppState::default());
}

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/", get(|_: Extractor| async {}))
.with_state(AppState::default());
}

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/b", get(|_: (), _: AppState| async {}))
.route("/c", get(|_: (), _: InnerState| async {}))
.with_state(AppState::default());

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/b", get(|_: AppState| async {}))
.with_state(AppState::default());
}

View file

@ -6,7 +6,7 @@ use axum::{
use axum_macros::FromRequestParts;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/a", get(|_: AppState, _: InnerState, _: String| async {}))
.route("/b", get(|_: AppState, _: String| async {}))
.route("/c", get(|_: InnerState, _: String| async {}))

View file

@ -8,7 +8,7 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: axum::routing::RouterService = Router::new()
let _: axum::Router = Router::new()
.route("/a", get(|_: Extractor| async {}))
.with_state(AppState::default());
}

View file

@ -7,9 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **added:** Add `RouterService::{layer, route_layer}` ([#1550])
- **breaking:** `RouterService` has been removed since `Router` now implements
`Service` when the state is `()`. Use `Router::with_state` to provide the
state and get a `Router<()>` ([#1552])
[#1550]: https://github.com/tokio-rs/axum/pull/1550
[#1552]: https://github.com/tokio-rs/axum/pull/1552
# 0.6.0-rc.5 (18. November, 2022)

View file

@ -1,7 +1,7 @@
use axum::{
extract::State,
routing::{get, post},
Extension, Json, Router, RouterService, Server,
Extension, Json, Router, Server,
};
use hyper::server::conn::AddrIncoming;
use serde::{Deserialize, Serialize};
@ -17,13 +17,9 @@ fn main() {
ensure_rewrk_is_installed();
}
benchmark("minimal").run(|| Router::new().into_service());
benchmark("minimal").run(Router::new);
benchmark("basic").run(|| {
Router::new()
.route("/", get(|| async { "Hello, World!" }))
.into_service()
});
benchmark("basic").run(|| Router::new().route("/", get(|| async { "Hello, World!" })));
benchmark("routing").path("/foo/bar/baz").run(|| {
let mut app = Router::new();
@ -34,32 +30,26 @@ fn main() {
}
}
}
app.route("/foo/bar/baz", get(|| async {})).into_service()
app.route("/foo/bar/baz", get(|| async {}))
});
benchmark("receive-json")
.method("post")
.headers(&[("content-type", "application/json")])
.body(r#"{"n": 123, "s": "hi there", "b": false}"#)
.run(|| {
Router::new()
.route("/", post(|_: Json<Payload>| async {}))
.into_service()
});
.run(|| Router::new().route("/", post(|_: Json<Payload>| async {})));
benchmark("send-json").run(|| {
Router::new()
.route(
"/",
get(|| async {
Json(Payload {
n: 123,
s: "hi there".to_owned(),
b: false,
})
}),
)
.into_service()
Router::new().route(
"/",
get(|| async {
Json(Payload {
n: 123,
s: "hi there".to_owned(),
b: false,
})
}),
)
});
let state = AppState {
@ -75,7 +65,6 @@ fn main() {
Router::new()
.route("/", get(|_: Extension<AppState>| async {}))
.layer(Extension(state.clone()))
.into_service()
});
benchmark("state").run(|| {
@ -133,7 +122,7 @@ impl BenchmarkBuilder {
fn run<F>(self, f: F)
where
F: FnOnce() -> RouterService,
F: FnOnce() -> Router<()>,
{
// support only running some benchmarks with
// ```

View file

@ -1,6 +1,14 @@
use std::{convert::Infallible, fmt};
use crate::{body::HttpBody, handler::Handler, routing::Route, Router};
use http::Request;
use tower::Service;
use crate::{
body::HttpBody,
handler::Handler,
routing::{future::RouteFuture, Route},
Router,
};
pub(crate) struct BoxedIntoRoute<S, B, E>(Box<dyn ErasedIntoRoute<S, B, E>>);
@ -13,6 +21,7 @@ where
where
H: Handler<T, S, B>,
T: 'static,
B: HttpBody,
{
Self(Box::new(MakeErasedHandler {
handler,
@ -30,6 +39,14 @@ where
into_route: |router, state| Route::new(router.with_state(state)),
}))
}
pub(crate) fn call_with_state(
self,
request: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
self.0.call_with_state(request, state)
}
}
impl<S, B, E> BoxedIntoRoute<S, B, E> {
@ -39,7 +56,7 @@ impl<S, B, E> BoxedIntoRoute<S, B, E> {
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
B2: HttpBody + 'static,
E2: 'static,
{
BoxedIntoRoute(Box::new(Map {
@ -69,6 +86,8 @@ pub(crate) trait ErasedIntoRoute<S, B, E>: Send {
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, E>>;
fn into_route(self: Box<Self>, state: S) -> Route<B, E>;
fn call_with_state(self: Box<Self>, request: Request<B>, state: S) -> RouteFuture<B, E>;
}
pub(crate) struct MakeErasedHandler<H, S, B> {
@ -80,7 +99,7 @@ impl<H, S, B> ErasedIntoRoute<S, B, Infallible> for MakeErasedHandler<H, S, B>
where
H: Clone + Send + 'static,
S: 'static,
B: 'static,
B: HttpBody + 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, Infallible>> {
Box::new(self.clone())
@ -89,6 +108,14 @@ where
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.handler, state)
}
fn call_with_state(
self: Box<Self>,
request: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
self.into_route(state).call(request)
}
}
impl<H, S, B> Clone for MakeErasedHandler<H, S, B>
@ -110,8 +137,8 @@ pub(crate) struct MakeErasedRouter<S, B> {
impl<S, B> ErasedIntoRoute<S, B, Infallible> for MakeErasedRouter<S, B>
where
S: Clone + Send + 'static,
B: 'static,
S: Clone + Send + Sync + 'static,
B: HttpBody + Send + 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, Infallible>> {
Box::new(self.clone())
@ -120,6 +147,14 @@ where
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.router, state)
}
fn call_with_state(
mut self: Box<Self>,
request: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
self.router.call_with_state(request, state)
}
}
impl<S, B> Clone for MakeErasedRouter<S, B>
@ -144,7 +179,7 @@ where
S: 'static,
B: 'static,
E: 'static,
B2: 'static,
B2: HttpBody + 'static,
E2: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B2, E2>> {
@ -157,6 +192,10 @@ where
fn into_route(self: Box<Self>, state: S) -> Route<B2, E2> {
(self.layer)(self.inner.into_route(state))
}
fn call_with_state(self: Box<Self>, request: Request<B2>, state: S) -> RouteFuture<B2, E2> {
(self.layer)(self.inner.into_route(state)).call(request)
}
}
pub(crate) trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {

View file

@ -466,7 +466,7 @@ let app = Router::new()
.route("/", get(handler))
.layer(MyLayer { state: state.clone() })
.with_state(state);
# let _: axum::routing::RouterService = app;
# let _: axum::Router = app;
```
# Passing state from middleware to handlers
@ -556,7 +556,7 @@ async fn rewrite_request_uri<B>(req: Request<B>, next: Next<B>) -> Response {
// this can be any `tower::Layer`
let middleware = axum::middleware::from_fn(rewrite_request_uri);
let app = Router::new().into_service();
let app = Router::new();
// apply the layer around the whole `Router`
// this way the middleware will run before `Router` receives the request

View file

@ -2,10 +2,6 @@ Convert this router into a [`MakeService`], that will store `C`'s
associated `ConnectInfo` in a request extension such that [`ConnectInfo`]
can extract it.
This is a convenience method for routers that don't have any state (i.e. the
state type is `()`). Use [`RouterService::into_make_service_with_connect_info`]
otherwise.
This enables extracting things like the client's remote address.
Extracting [`std::net::SocketAddr`] is supported out of the box:

View file

@ -39,13 +39,39 @@ let app = Router::new()
# Merging routers with state
When combining [`Router`]s with this function, each [`Router`] must have the
same type of state. See ["Combining stateful routers"][combining-stateful-routers]
for details.
When combining [`Router`]s with this method, each [`Router`] must have the
same type of state. If your routers have different types you can use
[`Router::with_state`] to provide the state and make the types match:
```rust
use axum::{
Router,
routing::get,
extract::State,
};
#[derive(Clone)]
struct InnerState {}
#[derive(Clone)]
struct OuterState {}
async fn inner_handler(state: State<InnerState>) {}
let inner_router = Router::new()
.route("/bar", get(inner_handler))
.with_state(InnerState {});
async fn outer_handler(state: State<OuterState>) {}
let app = Router::new()
.route("/", get(outer_handler))
.merge(inner_router)
.with_state(OuterState {});
# let _: axum::Router = app;
```
# Panics
- If two routers that each have a [fallback](Router::fallback) are merged. This
is because `Router` only allows a single fallback.
[combining-stateful-routers]: crate::extract::State#combining-stateful-routers

View file

@ -149,12 +149,40 @@ Here requests like `GET /api/not-found` will go to `api_fallback`.
# Nesting routers with state
When combining [`Router`]s with this function, each [`Router`] must have the
same type of state. See ["Combining stateful routers"][combining-stateful-routers]
for details.
When combining [`Router`]s with this method, each [`Router`] must have the
same type of state. If your routers have different types you can use
[`Router::with_state`] to provide the state and make the types match:
If you want to compose axum services with different types of state, use
[`Router::nest_service`].
```rust
use axum::{
Router,
routing::get,
extract::State,
};
#[derive(Clone)]
struct InnerState {}
#[derive(Clone)]
struct OuterState {}
async fn inner_handler(state: State<InnerState>) {}
let inner_router = Router::new()
.route("/bar", get(inner_handler))
.with_state(InnerState {});
async fn outer_handler(state: State<OuterState>) {}
let app = Router::new()
.route("/", get(outer_handler))
.nest("/foo", inner_router)
.with_state(OuterState {});
# let _: axum::Router = app;
```
Note that the inner router will still inherit the fallback from the outer
router.
# Panics
@ -165,4 +193,3 @@ for more details.
[`OriginalUri`]: crate::extract::OriginalUri
[fallbacks]: Router::fallback
[combining-stateful-routers]: crate::extract::State#combining-stateful-routers

View file

@ -69,7 +69,7 @@ use axum::{routing::get, Router};
let app = Router::new().route_service(
"/",
Router::new().route("/foo", get(|| async {})).into_service(),
Router::new().route("/foo", get(|| async {})),
);
# async {
# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();

View file

@ -43,7 +43,7 @@ use std::{
/// ) {
/// // use `state`...
/// }
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
///
/// ## Combining stateful routers
@ -71,19 +71,19 @@ use std::{
/// async fn posts_handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
///
/// However, if you are composing [`Router`]s that are defined in separate scopes,
/// you may need to annotate the [`State`] type explicitly:
///
/// ```
/// use axum::{Router, RouterService, routing::get, extract::State};
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// fn make_app() -> RouterService {
/// fn make_app() -> Router {
/// let state = AppState {};
///
/// Router::new()
@ -101,19 +101,15 @@ use std::{
/// async fn posts_handler(State(state): State<AppState>) {
/// // use `state`...
/// }
/// # let _: axum::routing::RouterService = make_app();
/// # let _: axum::Router = make_app();
/// ```
///
/// In short, a [`Router`]'s generic state type defaults to `()`
/// (no state) unless [`Router::with_state`] is called or the value
/// of the generic type is given explicitly.
///
/// It's also possible to combine multiple axum services with different state
/// types. See [`Router::nest_service`] for details.
///
/// [`Router`]: crate::Router
/// [`Router::merge`]: crate::Router::merge
/// [`Router::nest_service`]: crate::Router::nest_service
/// [`Router::nest`]: crate::Router::nest
/// [`Router::with_state`]: crate::Router::with_state
///
@ -209,7 +205,7 @@ use std::{
/// State(state): State<AppState>,
/// ) {
/// }
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
///
/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.

View file

@ -358,7 +358,7 @@ mod tests {
format!("you said: {}", body)
}
let client = TestClient::from_service(handle.into_service());
let client = TestClient::new(handle.into_service());
let res = client.post("/").body("hi there!").send().await;
assert_eq!(res.status(), StatusCode::OK);
@ -381,7 +381,7 @@ mod tests {
.layer(MapRequestBodyLayer::new(body::boxed))
.with_state("foo");
let client = TestClient::from_service(svc);
let client = TestClient::new(svc);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "foo");
}

View file

@ -476,7 +476,7 @@ pub use self::extension::Extension;
#[cfg(feature = "json")]
pub use self::json::Json;
#[doc(inline)]
pub use self::routing::{Router, RouterService};
pub use self::routing::Router;
#[doc(inline)]
#[cfg(feature = "headers")]

View file

@ -137,7 +137,7 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer {
@ -381,7 +381,6 @@ mod tests {
.layer(from_fn(insert_header));
let res = app
.into_service()
.oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
.await
.unwrap();

View file

@ -152,7 +152,7 @@ pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(map_request_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
MapRequestLayer {

View file

@ -136,7 +136,7 @@ pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(map_response_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// # let _: axum::Router = app;
/// ```
pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
MapResponseLayer {

View file

@ -1,6 +1,6 @@
//! Route to services and handlers based on HTTP methods.
use super::{FallbackRoute, IntoMakeService};
use super::IntoMakeService;
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
@ -83,7 +83,7 @@ macro_rules! top_level_service_fn {
T: Service<Request<B>> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
B: HttpBody + Send + 'static,
S: Clone,
{
on_service(MethodFilter::$method, svc)
@ -143,7 +143,7 @@ macro_rules! top_level_handler_fn {
pub fn $name<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
where
H: Handler<T, S, B>,
B: Send + 'static,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
@ -327,7 +327,7 @@ where
T: Service<Request<B>> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
B: HttpBody + Send + 'static,
S: Clone,
{
MethodRouter::new().on_service(filter, svc)
@ -391,7 +391,7 @@ where
T: Service<Request<B>> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
B: HttpBody + Send + 'static,
S: Clone,
{
MethodRouter::new()
@ -430,7 +430,7 @@ top_level_handler_fn!(trace, TRACE);
pub fn on<H, T, S, B>(filter: MethodFilter, handler: H) -> MethodRouter<S, B, Infallible>
where
H: Handler<T, S, B>,
B: Send + 'static,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
@ -477,7 +477,7 @@ where
pub fn any<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
where
H: Handler<T, S, B>,
B: Send + 'static,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
@ -571,7 +571,7 @@ impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
impl<S, B> MethodRouter<S, B, Infallible>
where
B: Send + 'static,
B: HttpBody + Send + 'static,
S: Clone,
{
/// Chain an additional handler that will accept requests matching the given
@ -633,7 +633,7 @@ where
impl<B> MethodRouter<(), B, Infallible>
where
B: Send + 'static,
B: HttpBody + Send + 'static,
{
/// Convert the handler into a [`MakeService`].
///
@ -665,7 +665,7 @@ where
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<Self> {
IntoMakeService::new(self)
IntoMakeService::new(self.with_state(()))
}
/// Convert the router into a [`MakeService`] which stores information
@ -701,13 +701,13 @@ where
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
IntoMakeServiceWithConnectInfo::new(self)
IntoMakeServiceWithConnectInfo::new(self.with_state(()))
}
}
impl<S, B, E> MethodRouter<S, B, E>
where
B: Send + 'static,
B: HttpBody + Send + 'static,
S: Clone,
{
/// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
@ -731,21 +731,19 @@ where
}
}
/// Provide the state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state(self, state: S) -> WithState<B, E> {
WithState {
get: self.get.into_route(&state),
head: self.head.into_route(&state),
delete: self.delete.into_route(&state),
options: self.options.into_route(&state),
patch: self.patch.into_route(&state),
post: self.post.into_route(&state),
put: self.put.into_route(&state),
trace: self.trace.into_route(&state),
fallback: self.fallback.into_fallback_route(&state),
/// Provide the state for the router.
pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, B, E> {
MethodRouter {
get: self.get.with_state(state.clone()),
head: self.head.with_state(state.clone()),
delete: self.delete.with_state(state.clone()),
options: self.options.with_state(state.clone()),
patch: self.patch.with_state(state.clone()),
post: self.post.with_state(state.clone()),
put: self.put.with_state(state.clone()),
trace: self.trace.with_state(state.clone()),
allow_header: self.allow_header,
fallback: self.fallback.with_state(state),
}
}
@ -918,10 +916,7 @@ where
}
#[doc = include_str!("../docs/method_routing/layer.md")]
pub fn layer<L, NewReqBody: 'static, NewError: 'static>(
self,
layer: L,
) -> MethodRouter<S, NewReqBody, NewError>
pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
where
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
@ -930,6 +925,8 @@ where
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
E: 'static,
S: 'static,
NewReqBody: HttpBody + 'static,
NewError: 'static,
{
let layer_fn = move |route: Route<B, E>| route.layer(layer.clone());
@ -1069,226 +1066,8 @@ where
self.allow_header = AllowHeader::Skip;
self
}
}
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
match allow_header {
AllowHeader::None => {
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
}
AllowHeader::Skip => {}
AllowHeader::Bytes(allow_header) => {
if let Ok(s) = std::str::from_utf8(allow_header) {
if !s.contains(method) {
allow_header.extend_from_slice(b",");
allow_header.extend_from_slice(method.as_bytes());
}
} else {
#[cfg(debug_assertions)]
panic!("`allow_header` contained invalid uft-8. This should never happen")
}
}
}
}
impl<B, E> Service<Request<B>> for MethodRouter<(), B, E>
where
B: HttpBody + Send + 'static,
{
type Response = Response;
type Error = E;
type Future = RouteFuture<B, E>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<B>) -> Self::Future {
self.clone().with_state(()).call(req)
}
}
impl<S, B, E> Clone for MethodRouter<S, B, E> {
fn clone(&self) -> Self {
Self {
get: self.get.clone(),
head: self.head.clone(),
delete: self.delete.clone(),
options: self.options.clone(),
patch: self.patch.clone(),
post: self.post.clone(),
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
}
}
}
impl<S, B, E> Default for MethodRouter<S, B, E>
where
B: Send + 'static,
S: Clone,
{
fn default() -> Self {
Self::new()
}
}
enum MethodEndpoint<S, B, E> {
None,
Route(Route<B, E>),
BoxedHandler(BoxedIntoRoute<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E>
where
S: Clone,
{
fn is_some(&self) -> bool {
matches!(self, Self::Route(_) | Self::BoxedHandler(_))
}
fn is_none(&self) -> bool {
matches!(self, Self::None)
}
fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(f(route)),
Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
}
}
fn into_route(self, state: &S) -> Option<Route<B, E>> {
match self {
Self::None => None,
Self::Route(route) => Some(route),
Self::BoxedHandler(handler) => Some(handler.into_route(state.clone())),
}
}
}
impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::None => Self::None,
Self::Route(inner) => Self::Route(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.debug_tuple("None").finish(),
Self::Route(inner) => inner.fmt(f),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
/// A [`MethodRouter`] which has access to some state.
///
/// Implements [`Service`].
///
/// The state can be extracted with [`State`](crate::extract::State).
///
/// Created with [`MethodRouter::with_state`]
pub struct WithState<B, E> {
get: Option<Route<B, E>>,
head: Option<Route<B, E>>,
delete: Option<Route<B, E>>,
options: Option<Route<B, E>>,
patch: Option<Route<B, E>>,
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: FallbackRoute<B, E>,
allow_header: AllowHeader,
}
impl<B, E> WithState<B, E> {
/// Convert the handler into a [`MakeService`].
///
/// See [`MethodRouter::into_make_service`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<Self> {
IntoMakeService::new(self)
}
/// Convert the router into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`MethodRouter::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
IntoMakeServiceWithConnectInfo::new(self)
}
}
impl<B, E> Clone for WithState<B, E> {
fn clone(&self) -> Self {
Self {
get: self.get.clone(),
head: self.head.clone(),
delete: self.delete.clone(),
options: self.options.clone(),
patch: self.patch.clone(),
post: self.post.clone(),
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
}
}
}
impl<B, E> fmt::Debug for WithState<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WithState")
.field("get", &self.get)
.field("head", &self.head)
.field("delete", &self.delete)
.field("options", &self.options)
.field("patch", &self.patch)
.field("post", &self.post)
.field("put", &self.put)
.field("trace", &self.trace)
.field("fallback", &self.fallback)
.field("allow_header", &self.allow_header)
.finish()
}
}
impl<B, E> Service<Request<B>> for WithState<B, E>
where
B: HttpBody + Send,
{
type Response = Response;
type Error = E;
type Future = RouteFuture<B, E>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<B>) -> Self::Future {
pub(crate) fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
macro_rules! call {
(
$req:expr,
@ -1297,9 +1076,17 @@ where
$svc:expr
) => {
if $method == Method::$method_variant {
if let Some(svc) = $svc {
return RouteFuture::from_future(svc.oneshot_inner($req))
.strip_body($method == Method::HEAD);
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
MethodEndpoint::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
return RouteFuture::from_future(route.oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
}
}
};
@ -1331,7 +1118,15 @@ where
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
let future = RouteFuture::from_future(fallback.oneshot_inner(req));
let future = match fallback {
Fallback::Default(route) | Fallback::Service(route) => {
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req))
}
};
match allow_header {
AllowHeader::None => future.allow_header(Bytes::new()),
@ -1341,6 +1136,137 @@ where
}
}
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
match allow_header {
AllowHeader::None => {
*allow_header = AllowHeader::Bytes(BytesMut::from(method));
}
AllowHeader::Skip => {}
AllowHeader::Bytes(allow_header) => {
if let Ok(s) = std::str::from_utf8(allow_header) {
if !s.contains(method) {
allow_header.extend_from_slice(b",");
allow_header.extend_from_slice(method.as_bytes());
}
} else {
#[cfg(debug_assertions)]
panic!("`allow_header` contained invalid uft-8. This should never happen")
}
}
}
}
impl<S, B, E> Clone for MethodRouter<S, B, E> {
fn clone(&self) -> Self {
Self {
get: self.get.clone(),
head: self.head.clone(),
delete: self.delete.clone(),
options: self.options.clone(),
patch: self.patch.clone(),
post: self.post.clone(),
put: self.put.clone(),
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
}
}
}
impl<S, B, E> Default for MethodRouter<S, B, E>
where
B: HttpBody + Send + 'static,
S: Clone,
{
fn default() -> Self {
Self::new()
}
}
enum MethodEndpoint<S, B, E> {
None,
Route(Route<B, E>),
BoxedHandler(BoxedIntoRoute<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E>
where
S: Clone,
{
fn is_some(&self) -> bool {
matches!(self, Self::Route(_) | Self::BoxedHandler(_))
}
fn is_none(&self) -> bool {
matches!(self, Self::None)
}
fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: HttpBody + 'static,
E2: 'static,
{
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(f(route)),
Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
}
}
fn with_state<S2>(self, state: S) -> MethodEndpoint<S2, B, E> {
match self {
MethodEndpoint::None => MethodEndpoint::None,
MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
MethodEndpoint::BoxedHandler(handler) => {
MethodEndpoint::Route(handler.into_route(state))
}
}
}
}
impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::None => Self::None,
Self::Route(inner) => Self::Route(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None => f.debug_tuple("None").finish(),
Self::Route(inner) => inner.fmt(f),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
impl<B, E> Service<Request<B>> for MethodRouter<(), B, E>
where
B: HttpBody + Send + 'static,
{
type Response = Response;
type Error = E;
type Future = RouteFuture<B, E>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.call_with_state(req, ())
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -1,6 +1,6 @@
//! Routing between [`Service`]s and handlers.
use self::{not_found::NotFound, strip_prefix::StripPrefix};
use self::{future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
@ -12,8 +12,14 @@ use crate::{
use axum_core::response::{IntoResponse, Response};
use http::Request;
use matchit::MatchError;
use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc};
use tower::util::{BoxCloneService, Oneshot};
use std::{
collections::HashMap,
convert::Infallible,
fmt,
sync::Arc,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
use tower_layer::Layer;
use tower_service::Service;
@ -27,14 +33,10 @@ mod route;
mod strip_prefix;
pub(crate) mod url_params;
mod service;
#[cfg(test)]
mod tests;
pub use self::{
into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route,
service::RouterService,
};
pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
pub use self::method_routing::{
any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
@ -168,10 +170,10 @@ where
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let service = match try_downcast::<RouterService<B>, _>(service) {
let service = match try_downcast::<Router<S, B>, _>(service) {
Ok(_) => {
panic!(
"Invalid route: `Router::route_service` cannot be used with `RouterService`s. \
"Invalid route: `Router::route_service` cannot be used with `Router`s. \
Use `Router::nest` instead"
);
}
@ -212,41 +214,6 @@ where
}
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
///
/// While [`nest`](Self::nest) requires [`Router`]s with the same type of
/// state, you can use this method to combine [`Router`]s with different
/// types of state:
///
/// ```
/// use axum::{
/// Router,
/// routing::get,
/// extract::State,
/// };
///
/// #[derive(Clone)]
/// struct InnerState {}
///
/// #[derive(Clone)]
/// struct OuterState {}
///
/// async fn inner_handler(state: State<InnerState>) {}
///
/// let inner_router = Router::new()
/// .route("/bar", get(inner_handler))
/// .with_state(InnerState {});
///
/// async fn outer_handler(state: State<OuterState>) {}
///
/// let app = Router::new()
/// .route("/", get(outer_handler))
/// .nest_service("/foo", inner_router)
/// .with_state(OuterState {});
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// Note that the inner router will still inherit the fallback from the outer
/// router.
#[track_caller]
pub fn nest_service<T>(self, path: &str, svc: T) -> Self
where
@ -353,7 +320,7 @@ where
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
NewReqBody: HttpBody + 'static,
{
let routes = self
.routes
@ -429,11 +396,171 @@ where
self
}
/// Convert this router into a [`RouterService`] by providing the state.
/// Provide the state for the router.
///
/// Once this method has been called you cannot add more routes. So it must be called as last.
pub fn with_state(self, state: S) -> RouterService<B> {
RouterService::new(self, state)
/// This method returns a router with a different state type. This can be used to nest or merge
/// routers with different state types. See [`Router::nest`] and [`Router::merge`] for more
/// details.
///
/// # Implementing `Service`
///
/// This can also be used to get a `Router` that implements [`Service`], since it only does so
/// when the state is `()`:
///
/// ```
/// use axum::{
/// Router,
/// body::Body,
/// http::Request,
/// };
/// use tower::{Service, ServiceExt};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// // this router doesn't implement `Service` because its state isn't `()`
/// let router: Router<AppState> = Router::new();
///
/// // by providing the state and setting the new state to `()`...
/// let router_service: Router<()> = router.with_state(AppState {});
///
/// // ...makes it implement `Service`
/// # async {
/// router_service.oneshot(Request::new(Body::empty())).await;
/// # };
/// ```
///
/// # A note about performance
///
/// If you need a `Router` that implements `Service` but you don't need any state (perhaps
/// you're making a library that uses axum internally) then it is recommended to call this
/// method before you start serving requests:
///
/// ```
/// use axum::{Router, routing::get};
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// // even though we don't need any state, call `with_state(())` anyway
/// .with_state(());
/// # let _: Router = app;
/// ```
///
/// This is not required but it gives axum a chance to update some internals in the router
/// which may impact performance and reduce allocations.
///
/// Note that [`Router::into_make_service`] and [`Router::into_make_service_with_connect_info`]
/// do this automatically.
pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
let routes = self
.routes
.into_iter()
.map(|(id, endpoint)| {
let endpoint: Endpoint<S2, B> = match endpoint {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => Endpoint::Route(route),
Endpoint::NestedRouter(router) => {
Endpoint::Route(router.into_route(state.clone()))
}
};
(id, endpoint)
})
.collect();
let fallback = self.fallback.with_state(state);
Router {
routes,
node: self.node,
fallback,
}
}
pub(crate) fn call_with_state(
&mut self,
mut req: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
#[cfg(feature = "original-uri")]
{
use crate::extract::OriginalUri;
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
}
let path = req.uri().path().to_owned();
match self.node.at(&path) {
Ok(match_) => {
match &self.fallback {
Fallback::Default(_) => {}
Fallback::Service(fallback) => {
req.extensions_mut()
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
}
Fallback::BoxedHandler(fallback) => {
req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
fallback.clone().into_route(state.clone()),
)));
}
}
self.call_route(match_, req, state)
}
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => match &mut self.fallback {
Fallback::Default(fallback) => {
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
{
let mut super_fallback = super_fallback.0.into_inner();
super_fallback.call(req)
} else {
fallback.call(req)
}
}
Fallback::Service(fallback) => fallback.call(req),
Fallback::BoxedHandler(handler) => handler.clone().into_route(state).call(req),
},
}
}
#[inline]
fn call_route(
&self,
match_: matchit::Match<&RouteId>,
mut req: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
let id = *match_.value;
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let endpont = self
.routes
.get(&id)
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
match endpont {
Endpoint::MethodRouter(mut method_router) => method_router.call_with_state(req, state),
Endpoint::Route(mut route) => route.call(req),
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
}
}
}
@ -441,16 +568,6 @@ impl<B> Router<(), B>
where
B: HttpBody + Send + 'static,
{
/// Convert this router into a [`RouterService`].
///
/// This is a convenience method for routers that don't have any state (i.e. the state type is
/// `()`). Use [`Router::with_state`] otherwise.
///
/// Once this method has been called you cannot add more routes. So it must be called as last.
pub fn into_service(self) -> RouterService<B> {
RouterService::new(self, ())
}
/// Convert this router into a [`MakeService`], that is a [`Service`] whose
/// response is another service.
///
@ -473,20 +590,38 @@ where
/// # };
/// ```
///
/// This is a convenience method for routers that don't have any state (i.e. the state type is
/// `()`). Use [`RouterService::into_make_service`] otherwise.
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<RouterService<B>> {
IntoMakeService::new(self.into_service())
pub fn into_make_service(self) -> IntoMakeService<Self> {
// call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request
IntoMakeService::new(self.with_state(()))
}
#[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(
self,
) -> IntoMakeServiceWithConnectInfo<RouterService<B>, C> {
IntoMakeServiceWithConnectInfo::new(self.into_service())
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
// call `Router::with_state` such that everything is turned into `Route` eagerly
// rather than doing that per request
IntoMakeServiceWithConnectInfo::new(self.with_state(()))
}
}
impl<B> Service<Request<B>> for Router<(), B>
where
B: HttpBody + Send + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = RouteFuture<B, Infallible>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.call_with_state(req, ())
}
}
@ -549,23 +684,13 @@ where
}
}
fn into_fallback_route(self, state: &S) -> FallbackRoute<B, E> {
match self {
Self::Default(route) => FallbackRoute::Default(route),
Self::Service(route) => FallbackRoute::Service(route),
Self::BoxedHandler(handler) => {
FallbackRoute::Service(handler.into_route(state.clone()))
}
}
}
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
B2: HttpBody + 'static,
E2: 'static,
{
match self {
@ -574,6 +699,14 @@ where
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
}
}
fn with_state<S2>(self, state: S) -> Fallback<S2, B, E> {
match self {
Fallback::Default(route) => Fallback::Default(route),
Fallback::Service(route) => Fallback::Service(route),
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
}
}
}
impl<S, B, E> Clone for Fallback<S, B, E> {
@ -596,61 +729,7 @@ impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
}
}
/// Like `Fallback` but without the `S` param so it can be stored in `RouterService`
pub(crate) enum FallbackRoute<B, E = Infallible> {
Default(Route<B, E>),
Service(Route<B, E>),
}
impl<B, E> FallbackRoute<B, E> {
fn layer<L, NewReqBody, NewError>(self, layer: L) -> FallbackRoute<NewReqBody, NewError>
where
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
NewError: 'static,
{
match self {
FallbackRoute::Default(route) => FallbackRoute::Default(route.layer(layer)),
FallbackRoute::Service(route) => FallbackRoute::Service(route.layer(layer)),
}
}
}
impl<B, E> fmt::Debug for FallbackRoute<B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
}
}
}
impl<B, E> Clone for FallbackRoute<B, E> {
fn clone(&self) -> Self {
match self {
Self::Default(inner) => Self::Default(inner.clone()),
Self::Service(inner) => Self::Service(inner.clone()),
}
}
}
impl<B, E> FallbackRoute<B, E> {
pub(crate) fn oneshot_inner(
&mut self,
req: Request<B>,
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
match self {
FallbackRoute::Default(inner) => inner.oneshot_inner(req),
FallbackRoute::Service(inner) => inner.oneshot_inner(req),
}
}
}
#[allow(clippy::large_enum_variant)] // This type is only used at init time, probably fine
#[allow(clippy::large_enum_variant)]
enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>),
Route(Route<B>),
@ -662,14 +741,6 @@ where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
fn into_route(self, state: S) -> Route<B> {
match self {
Endpoint::MethodRouter(method_router) => Route::new(method_router.with_state(state)),
Endpoint::Route(route) => route,
Endpoint::NestedRouter(router) => router.into_route(state),
}
}
fn layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
@ -677,7 +748,7 @@ where
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
NewReqBody: HttpBody + 'static,
{
match self {
Endpoint::MethodRouter(method_router) => {
@ -721,6 +792,8 @@ enum RouterOrService<S, B, T> {
Service(T),
}
struct SuperFallback<B>(SyncWrapper<Route<B>>);
#[test]
#[allow(warnings)]
fn traits() {

View file

@ -1,225 +0,0 @@
use super::{
future::RouteFuture, url_params, FallbackRoute, IntoMakeService, Node, Route, RouteId, Router,
};
use crate::{
body::{Body, HttpBody},
response::Response,
};
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
use tower::Service;
use tower_layer::Layer;
/// A [`Router`] converted into a [`Service`].
#[derive(Debug)]
pub struct RouterService<B = Body> {
routes: HashMap<RouteId, Route<B>>,
node: Arc<Node>,
fallback: FallbackRoute<B>,
}
impl<B> RouterService<B>
where
B: HttpBody + Send + 'static,
{
pub(super) fn new<S>(router: Router<S, B>, state: S) -> Self
where
S: Clone + Send + Sync + 'static,
{
let fallback = router.fallback.into_fallback_route(&state);
let routes = router
.routes
.into_iter()
.map(|(route_id, endpoint)| {
let route = endpoint.into_route(state.clone());
(route_id, route)
})
.collect();
Self {
routes,
node: router.node,
fallback,
}
}
#[inline]
fn call_route(
&self,
match_: matchit::Match<&RouteId>,
mut req: Request<B>,
) -> RouteFuture<B, Infallible> {
let id = *match_.value;
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
url_params::insert_url_params(req.extensions_mut(), match_.params);
let mut route = self
.routes
.get(&id)
.expect("no route for id. This is a bug in axum. Please file an issue")
.clone();
route.call(req)
}
/// Apply a [`tower::Layer`] to all routes in the router.
///
/// See [`Router::layer`] for more details.
pub fn layer<L, NewReqBody>(self, layer: L) -> RouterService<NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
{
let routes = self
.routes
.into_iter()
.map(|(id, route)| (id, route.layer(layer.clone())))
.collect();
let fallback = self.fallback.layer(layer);
RouterService {
routes,
node: self.node,
fallback,
}
}
/// Apply a [`tower::Layer`] to the router that will only run if the request matches
/// a route.
///
/// See [`Router::route_layer`] for more details.
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
{
let routes = self
.routes
.into_iter()
.map(|(id, route)| (id, route.layer(layer.clone())))
.collect();
Self {
routes,
node: self.node,
fallback: self.fallback,
}
}
/// Convert the `RouterService` into a [`MakeService`].
///
/// See [`Router::into_make_service`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<Self> {
IntoMakeService::new(self)
}
/// Convert the `RouterService` into a [`MakeService`] which stores information
/// about the incoming connection.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(
self,
) -> crate::extract::connect_info::IntoMakeServiceWithConnectInfo<Self, C> {
crate::extract::connect_info::IntoMakeServiceWithConnectInfo::new(self)
}
}
impl<B> Clone for RouterService<B> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
}
}
}
impl<B> Service<Request<B>> for RouterService<B>
where
B: HttpBody + Send + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = RouteFuture<B, Infallible>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, mut req: Request<B>) -> Self::Future {
#[cfg(feature = "original-uri")]
{
use crate::extract::OriginalUri;
if req.extensions().get::<OriginalUri>().is_none() {
let original_uri = OriginalUri(req.uri().clone());
req.extensions_mut().insert(original_uri);
}
}
let path = req.uri().path().to_owned();
match self.node.at(&path) {
Ok(match_) => {
match &self.fallback {
FallbackRoute::Default(_) => {}
FallbackRoute::Service(fallback) => {
req.extensions_mut()
.insert(SuperFallback(SyncWrapper::new(fallback.clone())));
}
}
self.call_route(match_, req)
}
Err(
MatchError::NotFound
| MatchError::ExtraTrailingSlash
| MatchError::MissingTrailingSlash,
) => match &mut self.fallback {
FallbackRoute::Default(fallback) => {
if let Some(super_fallback) = req.extensions_mut().remove::<SuperFallback<B>>()
{
let mut super_fallback = super_fallback.0.into_inner();
super_fallback.call(req)
} else {
fallback.call(req)
}
}
FallbackRoute::Service(fallback) => fallback.call(req),
},
}
}
}
struct SuperFallback<B>(SyncWrapper<Route<B>>);

View file

@ -56,7 +56,7 @@ async fn fallback_accessing_state() {
.fallback(|State(state): State<&'static str>| async move { state })
.with_state("state");
let client = TestClient::from_service(app);
let client = TestClient::new(app);
let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);

View file

@ -19,7 +19,6 @@ mod for_handlers {
// don't use reqwest because it always strips bodies from HEAD responses
let res = app
.into_service()
.oneshot(
Request::builder()
.uri("/")
@ -55,7 +54,6 @@ mod for_services {
// don't use reqwest because it always strips bodies from HEAD responses
let res = app
.into_service()
.oneshot(
Request::builder()
.uri("/")

View file

@ -447,11 +447,11 @@ async fn middleware_still_run_for_unmatched_requests() {
#[tokio::test]
#[should_panic(expected = "\
Invalid route: `Router::route_service` cannot be used with `RouterService`s. \
Invalid route: `Router::route_service` cannot be used with `Router`s. \
Use `Router::nest` instead\
")]
async fn routing_to_router_panics() {
TestClient::new(Router::new().route_service("/", Router::new().into_service()));
TestClient::new(Router::new().route_service("/", Router::new()));
}
#[tokio::test]
@ -761,7 +761,7 @@ async fn extract_state() {
};
let app = Router::new().route("/", get(handler)).with_state(state);
let client = TestClient::from_service(app);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
@ -776,7 +776,7 @@ async fn explicitly_set_state() {
)
.with_state("...");
let client = TestClient::from_service(app);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "foo");
}

View file

@ -1,6 +1,6 @@
#![allow(clippy::disallowed_names)]
use crate::{body::HttpBody, BoxError, Router};
use crate::{body::HttpBody, BoxError};
mod test_client;
pub(crate) use self::test_client::*;

View file

@ -1,4 +1,4 @@
use super::{BoxError, HttpBody, Router};
use super::{BoxError, HttpBody};
use bytes::Bytes;
use http::{
header::{HeaderName, HeaderValue},
@ -15,11 +15,7 @@ pub(crate) struct TestClient {
}
impl TestClient {
pub(crate) fn new(router: Router<(), Body>) -> Self {
Self::from_service(router.into_service())
}
pub(crate) fn from_service<S, ResBody>(svc: S) -> Self
pub(crate) fn new<S, ResBody>(svc: S) -> Self
where
S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static,
ResBody: HttpBody + Send + 'static,

View file

@ -50,7 +50,7 @@ mod tests {
#[tokio::test]
async fn test_get() {
let app = app().into_service();
let app = app();
let response = app
.oneshot(Request::get("/get-head").body(Body::empty()).unwrap())
@ -66,7 +66,7 @@ mod tests {
#[tokio::test]
async fn test_implicit_head() {
let app = app().into_service();
let app = app();
let response = app
.oneshot(Request::head("/get-head").body(Body::empty()).unwrap())

View file

@ -35,9 +35,7 @@ async fn main() {
.with(tracing_subscriber::fmt::layer())
.init();
let router_svc = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.into_service();
let router_svc = Router::new().route("/", get(|| async { "Hello, World!" }));
let service = tower::service_fn(move |req: Request<Body>| {
let router_svc = router_svc.clone();

View file

@ -26,7 +26,7 @@ use std::{
use tower::{BoxError, ServiceBuilder};
use tower_http::{
auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer,
trace::TraceLayer, ServiceBuilderExt,
trace::TraceLayer,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

View file

@ -104,7 +104,6 @@ mod tests {
async fn send_request_get_body(query: &str) -> String {
let body = app()
.into_service()
.oneshot(
Request::builder()
.uri(format!("/?{}", query))

View file

@ -55,7 +55,7 @@ async fn main() {
.init();
// build the rest service
let rest = Router::new().route("/", get(web_root)).into_service();
let rest = Router::new().route("/", get(web_root));
// build the grpc service
let grpc = GreeterServer::new(GrpcServiceImpl::default());

View file

@ -40,8 +40,7 @@ fn main() {
#[allow(clippy::let_and_return)]
async fn app(request: Request<String>) -> Response {
let mut router = Router::new().route("/api/", get(index)).into_service();
let mut router = Router::new().route("/api/", get(index));
let response = router.call(request).await.unwrap();
response
}

View file

@ -61,7 +61,7 @@ mod tests {
#[tokio::test]
async fn hello_world() {
let app = app().into_service();
let app = app();
// `Router` implements `tower::Service<Request<Body>>` so we can
// call it like any tower service, no need to run an HTTP server.
@ -78,7 +78,7 @@ mod tests {
#[tokio::test]
async fn json() {
let app = app().into_service();
let app = app();
let response = app
.oneshot(
@ -103,7 +103,7 @@ mod tests {
#[tokio::test]
async fn not_found() {
let app = app().into_service();
let app = app();
let response = app
.oneshot(
@ -154,7 +154,7 @@ mod tests {
// in multiple request
#[tokio::test]
async fn multiple_request() {
let mut app = app().into_service();
let mut app = app();
let request = Request::builder().uri("/").body(Body::empty()).unwrap();
let response = app.ready().await.unwrap().call(request).await.unwrap();