Type safe state inheritance (#1532)

* Make state type safe

* fix examples

* remove unnecessary `#[track_caller]`s

* Router::into_service -> Router::with_state

* fixup docs

* macro docs

* add missing docs

* fix examples

* format

* changelog

* Update trybuild tests

* Make sure fallbacks are still inherited for opaque services (#1540)

* Document nesting routers with different state

* fix leftover conflicts
This commit is contained in:
David Pedersen 2022-11-18 12:02:58 +01:00 committed by GitHub
parent ba8e9c1b21
commit 64960bb19c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 675 additions and 736 deletions

View file

@ -34,13 +34,14 @@ use tower_layer::Layer;
/// http::Request,
/// };
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // even with `DefaultBodyLimit` the request body is still just `Body`
/// post(|request: Request<Body>| async {}),
/// )
/// .layer(DefaultBodyLimit::max(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// ```
@ -48,7 +49,7 @@ use tower_layer::Layer;
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let router = Router::new()
/// let app = Router::new()
/// .route(
/// "/",
/// // `RequestBodyLimitLayer` changes the request body type to `Limited<Body>`
@ -56,6 +57,7 @@ use tower_layer::Layer;
/// post(|request: Request<Limited<Body>>| async {}),
/// )
/// .layer(RequestBodyLimitLayer::new(1024));
/// # let _: Router<(), _> = app;
/// ```
///
/// In general using `DefaultBodyLimit` is recommended but if you need to use third party
@ -102,7 +104,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Disable the default limit
/// .layer(DefaultBodyLimit::disable())
@ -138,7 +140,7 @@ impl DefaultBodyLimit {
/// use tower_http::limit::RequestBodyLimitLayer;
/// use http_body::Limited;
///
/// let app: Router<_, Limited<Body>> = Router::new()
/// let app: Router<(), Limited<Body>> = Router::new()
/// .route("/", get(|body: Bytes| async {}))
/// // Replace the default of 2MB with 1024 bytes.
/// .layer(DefaultBodyLimit::max(1024));

View file

@ -255,11 +255,11 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/set", get(set_cookie))
.route("/get", get(get_cookie))
.route("/remove", get(remove_cookie))
.into_service();
.with_state(state);
let res = app
.clone()
@ -352,9 +352,9 @@ mod tests {
custom_key: CustomKey(Key::generate()),
};
let app = Router::<_, Body>::with_state(state)
let app = Router::<_, Body>::new()
.route("/get", get(get_cookie))
.into_service();
.with_state(state);
let res = app
.clone()

View file

@ -64,10 +64,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/set", post(set_secret))
/// .route("/get", get(get_secret));
/// # let app: Router<_> = app;
/// .route("/get", get(get_secret))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct PrivateCookieJar<K = Key> {
jar: cookie::CookieJar,

View file

@ -82,10 +82,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData};
/// key: Key::generate(),
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/sessions", post(create_session))
/// .route("/me", get(me));
/// # let app: Router<_> = app;
/// .route("/me", get(me))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub struct SignedCookieJar<K = Key> {
jar: cookie::CookieJar,

View file

@ -38,30 +38,18 @@ pub struct Resource<S = (), B = Body> {
pub(crate) router: Router<S, B>,
}
impl<B> Resource<(), B>
where
B: axum::body::HttpBody + Send + 'static,
{
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named(resource_name: &str) -> Self {
Self::named_with((), resource_name)
}
}
impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
/// Create a `Resource` with the given name.
///
/// All routes will be nested at `/{resource_name}`.
pub fn named_with(state: S, resource_name: &str) -> Self {
pub fn named(resource_name: &str) -> Self {
Self {
name: resource_name.to_owned(),
router: Router::with_state(state),
router: Router::new(),
}
}

View file

@ -50,10 +50,10 @@ use tower_service::Service;
/// - `GET /some/other/path` will serve `index.html` since there isn't another
/// route for it
/// - `GET /api/foo` will serve the `api_foo` handler function
pub struct SpaRouter<B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
pub struct SpaRouter<S = (), B = Body, T = (), F = fn(io::Error) -> Ready<StatusCode>> {
paths: Arc<Paths>,
handle_error: F,
_marker: PhantomData<fn() -> (B, T)>,
_marker: PhantomData<fn() -> (S, B, T)>,
}
#[derive(Debug)]
@ -63,7 +63,7 @@ struct Paths {
index_file: PathBuf,
}
impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
impl<S, B> SpaRouter<S, B, (), fn(io::Error) -> Ready<StatusCode>> {
/// Create a new `SpaRouter`.
///
/// Assets will be served at `GET /{serve_assets_at}` from the directory at `assets_dir`.
@ -86,7 +86,7 @@ impl<B> SpaRouter<B, (), fn(io::Error) -> Ready<StatusCode>> {
}
}
impl<B, T, F> SpaRouter<B, T, F> {
impl<S, B, T, F> SpaRouter<S, B, T, F> {
/// Set the path to the index file.
///
/// `path` must be relative to `assets_dir` passed to [`SpaRouter::new`].
@ -138,7 +138,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
/// let app = Router::new().merge(spa);
/// # let _: Router = app;
/// ```
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<B, T2, F2> {
pub fn handle_error<T2, F2>(self, f: F2) -> SpaRouter<S, B, T2, F2> {
SpaRouter {
paths: self.paths,
handle_error: f,
@ -147,7 +147,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
}
}
impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
impl<S, B, F, T> From<SpaRouter<S, B, T, F>> for Router<S, B>
where
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
@ -155,8 +155,9 @@ where
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
B: HttpBody + Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
{
fn from(spa: SpaRouter<B, T, F>) -> Self {
fn from(spa: SpaRouter<S, B, T, F>) -> Router<S, B> {
let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir))
.handle_error(spa.handle_error.clone());
@ -195,7 +196,7 @@ where
fn clone(&self) -> Self {
Self {
paths: self.paths.clone(),
handle_error: self.handle_error.clone(),
handle_error: self.handle_error,
_marker: self._marker,
}
}
@ -264,13 +265,14 @@ mod tests {
let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);
Router::<_, Body>::new().merge(spa);
Router::<(), Body>::new().merge(spa);
}
#[allow(dead_code)]
fn works_with_router_with_state() {
let _: Router<String> = Router::with_state(String::new())
let _: axum::RouterService = Router::new()
.merge(SpaRouter::new("/assets", "test_files"))
.route("/", get(|_: axum::extract::State<String>| async {}));
.route("/", get(|_: axum::extract::State<String>| async {}))
.with_state(String::new());
}
}

View file

@ -25,7 +25,7 @@ syn = { version = "1.0", features = [
] }
[dev-dependencies]
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] }
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers", "macros"] }
axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] }
rustversion = "1.0"
serde = { version = "1.0", features = ["derive"] }

View file

@ -181,7 +181,6 @@ use from_request::Trait::{FromRequest, FromRequestParts};
/// rejection type with `#[from_request(rejection(YourType))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{
/// rejection::{ExtensionRejection, StringRejection},
@ -463,8 +462,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// As the error message says, handler function needs to be async.
///
/// ```
/// use axum::{routing::get, Router};
/// use axum_macros::debug_handler;
/// use axum::{routing::get, Router, debug_handler};
///
/// #[tokio::main]
/// async fn main() {
@ -493,8 +491,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// To work around that the request body type can be customized like so:
///
/// ```
/// use axum::{body::BoxBody, http::Request};
/// # use axum_macros::debug_handler;
/// use axum::{body::BoxBody, http::Request, debug_handler};
///
/// #[debug_handler(body = BoxBody)]
/// async fn handler(request: Request<BoxBody>) {}
@ -506,8 +503,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// [`axum::extract::State`] argument:
///
/// ```
/// use axum::extract::State;
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::State};
///
/// #[debug_handler]
/// async fn handler(
@ -523,8 +519,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
///
/// ```
/// use axum::extract::{State, FromRef};
/// # use axum_macros::debug_handler;
/// use axum::{debug_handler, extract::{State, FromRef}};
///
/// #[debug_handler(state = AppState)]
/// async fn handler(
@ -579,8 +574,11 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// # Example
///
/// ```
/// use axum_macros::FromRef;
/// use axum::{Router, routing::get, extract::State};
/// use axum::{
/// Router,
/// routing::get,
/// extract::{State, FromRef},
/// };
///
/// #
/// # type AuthToken = String;
@ -605,8 +603,10 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
/// database_pool,
/// };
///
/// let app = Router::with_state(state).route("/", get(handler).post(other_handler));
/// # let _: Router<AppState> = app;
/// let app = Router::new()
/// .route("/", get(handler).post(other_handler))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html

View file

@ -1,5 +1,4 @@
use axum_macros::FromRef;
use axum::{Router, routing::get, extract::State};
use axum::{Router, routing::get, extract::{State, FromRef}};
// This will implement `FromRef` for each field in the struct.
#[derive(Clone, FromRef)]
@ -15,5 +14,7 @@ fn main() {
auth_token: Default::default(),
};
let _: Router<AppState> = Router::with_state(state).route("/", get(handler));
let _: axum::routing::RouterService = Router::new()
.route("/", get(handler))
.with_state(state);
}

View file

@ -1,5 +1,4 @@
use axum::{extract::FromRequestParts, response::Response};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {

View file

@ -1,7 +1,7 @@
error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
--> tests/from_request/fail/parts_extracting_body.rs:6:11
--> tests/from_request/fail/parts_extracting_body.rs:5:11
|
6 | body: String,
5 | body: String,
| ^^^^^^ the trait `FromRequestParts<S>` is not implemented for `String`
|
= help: the following other types implement trait `FromRequestParts<S>`:

View file

@ -3,7 +3,6 @@ use axum::{
extract::{FromRequest, Json},
response::Response,
};
use axum_macros::FromRequest;
use serde::Deserialize;
#[derive(Deserialize, FromRequest)]

View file

@ -2,7 +2,6 @@ use axum::{
extract::{FromRequestParts, Extension},
response::Response,
};
use axum_macros::FromRequestParts;
#[derive(Clone, FromRequestParts)]
#[from_request(via(Extension))]

View file

@ -4,7 +4,6 @@ use axum::{
response::Response,
headers::{self, UserAgent},
};
use axum_macros::FromRequest;
#[derive(FromRequest)]
struct Extractor {

View file

@ -3,7 +3,6 @@ use axum::{
headers::{self, UserAgent},
response::Response,
};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {

View file

@ -7,7 +7,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequest;
#[derive(FromRequest)]
struct Extractor {

View file

@ -6,7 +6,6 @@ use axum::{
},
headers::{self, UserAgent},
};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {

View file

@ -6,7 +6,6 @@ use axum::{
routing::get,
Extension, Router,
};
use axum_macros::FromRequest;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));

View file

@ -6,7 +6,6 @@ use axum::{
routing::get,
Extension, Router,
};
use axum_macros::FromRequestParts;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));

View file

@ -6,9 +6,10 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
let _: axum::routing::RouterService = Router::new()
.route("/a", get(|_: AppState| async {}))
.route("/b", get(|_: InnerState| async {}));
.route("/b", get(|_: InnerState| async {}))
.with_state(AppState::default());
}
#[derive(Clone, FromRequest)]

View file

@ -6,10 +6,11 @@ use axum::{
use axum_macros::FromRequestParts;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
let _: axum::routing::RouterService = Router::new()
.route("/a", get(|_: AppState| async {}))
.route("/b", get(|_: InnerState| async {}))
.route("/c", get(|_: AppState, _: InnerState| async {}));
.route("/c", get(|_: AppState, _: InnerState| async {}))
.with_state(AppState::default());
}
#[derive(Clone, FromRequestParts)]

View file

@ -6,8 +6,9 @@ use axum::{
};
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
.route("/b", get(|_: Extractor| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/b", get(|_: Extractor| async {}))
.with_state(AppState::default());
}
#[derive(FromRequest)]

View file

@ -7,8 +7,9 @@ use axum::{
use std::collections::HashMap;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
.route("/b", get(|_: Extractor| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/b", get(|_: Extractor| async {}))
.with_state(AppState::default());
}
#[derive(FromRequestParts)]

View file

@ -6,8 +6,9 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
.route("/", get(|_: Extractor| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/", get(|_: Extractor| async {}))
.with_state(AppState::default());
}
#[derive(FromRequest)]

View file

@ -6,8 +6,9 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
.route("/", get(|_: Extractor| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/", get(|_: Extractor| async {}))
.with_state(AppState::default());
}
#[derive(FromRequest)]

View file

@ -6,9 +6,10 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
let _: axum::routing::RouterService = Router::new()
.route("/b", get(|_: (), _: AppState| async {}))
.route("/c", get(|_: (), _: InnerState| async {}));
.route("/c", get(|_: (), _: InnerState| async {}))
.with_state(AppState::default());
}
#[derive(Clone, Default, FromRequest)]

View file

@ -6,8 +6,9 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
.route("/b", get(|_: AppState| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/b", get(|_: AppState| async {}))
.with_state(AppState::default());
}
// if we're extract "via" `State<AppState>` and not specifying state

View file

@ -6,10 +6,11 @@ use axum::{
use axum_macros::FromRequestParts;
fn main() {
let _: Router<AppState> = Router::with_state(AppState::default())
let _: axum::routing::RouterService = Router::new()
.route("/a", get(|_: AppState, _: InnerState, _: String| async {}))
.route("/b", get(|_: AppState, _: String| async {}))
.route("/c", get(|_: InnerState, _: String| async {}));
.route("/c", get(|_: InnerState, _: String| async {}))
.with_state(AppState::default());
}
#[derive(Clone, Default, FromRequestParts)]

View file

@ -8,8 +8,9 @@ use axum::{
use axum_macros::FromRequest;
fn main() {
let _: Router<AppState> =
Router::with_state(AppState::default()).route("/a", get(|_: Extractor| async {}));
let _: axum::routing::RouterService = Router::new()
.route("/a", get(|_: Extractor| async {}))
.with_state(AppState::default());
}
#[derive(Clone, Default, FromRequest)]

View file

@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **breaking:** `Router::with_state` is no longer a constructor. It is instead
used to convert the router into a `RouterService` ([#1532])
This nested router on 0.6.0-rc.4
```rust
Router::with_state(state).route(...);
```
Becomes this in 0.6.0-rc.5
```rust
Router::new().route(...).with_state(state);
```
- **breaking:**: `Router::nest` and `Router::merge` now only supports nesting
routers that use the same state type as the router they're being merged into.
Use `FromRef` for substates ([#1532])
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])
- **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521])
- **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529])
- **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do
@ -15,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1539]: https://github.com/tokio-rs/axum/pull/1539
[#1521]: https://github.com/tokio-rs/axum/pull/1521
[#1529]: https://github.com/tokio-rs/axum/pull/1529
[#1532]: https://github.com/tokio-rs/axum/pull/1532
# 0.6.0-rc.4 (9. November, 2022)

View file

@ -1,7 +1,7 @@
use axum::{
extract::State,
routing::{get, post},
Extension, Json, Router, Server,
Extension, Json, Router, RouterService, Server,
};
use hyper::server::conn::AddrIncoming;
use serde::{Deserialize, Serialize};
@ -17,9 +17,13 @@ fn main() {
ensure_rewrk_is_installed();
}
benchmark("minimal").run(Router::new);
benchmark("minimal").run(|| Router::new().into_service());
benchmark("basic").run(|| Router::new().route("/", get(|| async { "Hello, World!" })));
benchmark("basic").run(|| {
Router::new()
.route("/", get(|| async { "Hello, World!" }))
.into_service()
});
benchmark("routing").path("/foo/bar/baz").run(|| {
let mut app = Router::new();
@ -30,26 +34,32 @@ fn main() {
}
}
}
app.route("/foo/bar/baz", get(|| async {}))
app.route("/foo/bar/baz", get(|| async {})).into_service()
});
benchmark("receive-json")
.method("post")
.headers(&[("content-type", "application/json")])
.body(r#"{"n": 123, "s": "hi there", "b": false}"#)
.run(|| Router::new().route("/", post(|_: Json<Payload>| async {})));
.run(|| {
Router::new()
.route("/", post(|_: Json<Payload>| async {}))
.into_service()
});
benchmark("send-json").run(|| {
Router::new().route(
"/",
get(|| async {
Json(Payload {
n: 123,
s: "hi there".to_owned(),
b: false,
})
}),
)
Router::new()
.route(
"/",
get(|| async {
Json(Payload {
n: 123,
s: "hi there".to_owned(),
b: false,
})
}),
)
.into_service()
});
let state = AppState {
@ -65,10 +75,14 @@ fn main() {
Router::new()
.route("/", get(|_: Extension<AppState>| async {}))
.layer(Extension(state.clone()))
.into_service()
});
benchmark("state")
.run(|| Router::with_state(state.clone()).route("/", get(|_: State<AppState>| async {})));
benchmark("state").run(|| {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state.clone())
});
}
#[derive(Clone)]
@ -117,10 +131,9 @@ impl BenchmarkBuilder {
config_method!(headers, &'static [(&'static str, &'static str)]);
config_method!(body, &'static str);
fn run<F, S>(self, f: F)
fn run<F>(self, f: F)
where
F: FnOnce() -> Router<S>,
S: Clone + Send + Sync + 'static,
F: FnOnce() -> RouterService,
{
// support only running some benchmarks with
// ```

173
axum/src/boxed.rs Normal file
View file

@ -0,0 +1,173 @@
use std::{convert::Infallible, fmt};
use crate::{body::HttpBody, handler::Handler, routing::Route, Router};
pub(crate) struct BoxedIntoRoute<S, B, E>(Box<dyn ErasedIntoRoute<S, B, E>>);
impl<S, B> BoxedIntoRoute<S, B, Infallible>
where
S: Clone + Send + Sync + 'static,
B: Send + 'static,
{
pub(crate) fn from_handler<H, T>(handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
}))
}
pub(crate) fn from_router(router: Router<S, B>) -> Self
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
Self(Box::new(MakeErasedRouter {
router,
into_route: |router, state| Route::new(router.with_state(state)),
}))
}
}
impl<S, B, E> BoxedIntoRoute<S, B, E> {
pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedIntoRoute<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
BoxedIntoRoute(Box::new(Map {
inner: self.0,
layer: Box::new(f),
}))
}
pub(crate) fn into_route(self, state: S) -> Route<B, E> {
self.0.into_route(state)
}
}
impl<S, B, E> Clone for BoxedIntoRoute<S, B, E> {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}
impl<S, B, E> fmt::Debug for BoxedIntoRoute<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("BoxedIntoRoute").finish()
}
}
pub(crate) trait ErasedIntoRoute<S, B, E>: Send {
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, E>>;
fn into_route(self: Box<Self>, state: S) -> Route<B, E>;
}
pub(crate) struct MakeErasedHandler<H, S, B> {
pub(crate) handler: H,
pub(crate) into_route: fn(H, S) -> Route<B>,
}
impl<H, S, B> ErasedIntoRoute<S, B, Infallible> for MakeErasedHandler<H, S, B>
where
H: Clone + Send + 'static,
S: 'static,
B: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, Infallible>> {
Box::new(self.clone())
}
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.handler, state)
}
}
impl<H, S, B> Clone for MakeErasedHandler<H, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
into_route: self.into_route,
}
}
}
pub(crate) struct MakeErasedRouter<S, B> {
pub(crate) router: Router<S, B>,
pub(crate) into_route: fn(Router<S, B>, S) -> Route<B>,
}
impl<S, B> ErasedIntoRoute<S, B, Infallible> for MakeErasedRouter<S, B>
where
S: Clone + Send + 'static,
B: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B, Infallible>> {
Box::new(self.clone())
}
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.router, state)
}
}
impl<S, B> Clone for MakeErasedRouter<S, B>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
router: self.router.clone(),
into_route: self.into_route,
}
}
}
pub(crate) struct Map<S, B, E, B2, E2> {
pub(crate) inner: Box<dyn ErasedIntoRoute<S, B, E>>,
pub(crate) layer: Box<dyn LayerFn<B, E, B2, E2>>,
}
impl<S, B, E, B2, E2> ErasedIntoRoute<S, B2, E2> for Map<S, B, E, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
B2: 'static,
E2: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedIntoRoute<S, B2, E2>> {
Box::new(Self {
inner: self.inner.clone_box(),
layer: self.layer.clone_box(),
})
}
fn into_route(self: Box<Self>, state: S) -> Route<B2, E2> {
(self.layer)(self.inner.into_route(state))
}
}
pub(crate) trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>;
}
impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F
where
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
{
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> {
Box::new(self.clone())
}
}

View file

@ -462,10 +462,11 @@ async fn handler(_: State<AppState>) {}
let state = AppState {};
let app = Router::with_state(state.clone())
let app = Router::new()
.route("/", get(handler))
.layer(MyLayer { state });
# let _: Router<_> = app;
.layer(MyLayer { state: state.clone() })
.with_state(state);
# let _: axum::routing::RouterService = app;
```
# Passing state from middleware to handlers

View file

@ -2,6 +2,10 @@ Convert this router into a [`MakeService`], that will store `C`'s
associated `ConnectInfo` in a request extension such that [`ConnectInfo`]
can extract it.
This is a convenience method for routers that don't have any state (i.e. the
state type is `()`). Use [`RouterService::into_make_service_with_connect_info`]
otherwise.
This enables extracting things like the client's remote address.
Extracting [`std::net::SocketAddr`] is supported out of the box:

View file

@ -147,6 +147,43 @@ let app = Router::new()
Here requests like `GET /api/not-found` will go to `api_fallback`.
# Nesting a router with a different state type
By default `nest` requires a `Router` with the same state type as the outer
`Router`. If you need to nest a `Router` with a different state type you can
use [`Router::with_state`] and [`Router::nest_service`]:
```rust
use axum::{
Router,
routing::get,
extract::State,
};
#[derive(Clone)]
struct InnerState {}
#[derive(Clone)]
struct OuterState {}
async fn inner_handler(state: State<InnerState>) {}
let inner_router = Router::new()
.route("/bar", get(inner_handler))
.with_state(InnerState {});
async fn outer_handler(state: State<OuterState>) {}
let app = Router::new()
.route("/", get(outer_handler))
.nest_service("/foo", inner_router)
.with_state(OuterState {});
# let _: axum::routing::RouterService = app;
```
Note that the inner router will still inherit the fallback from the outer
router.
# Panics
- If the route overlaps with another route. See [`Router::route`]

View file

@ -31,7 +31,10 @@ use std::{
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::with_state(state).route("/", get(handler));
/// let app = Router::new()
/// .route("/", get(handler))
/// // provide the state so the router can access it
/// .with_state(state);
///
/// async fn handler(
/// // access the state via the `State` extractor
@ -40,7 +43,7 @@ use std::{
/// ) {
/// // use `state`...
/// }
/// # let _: Router<AppState> = app;
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// # With `MethodRouter`
@ -119,9 +122,10 @@ use std::{
/// api_state: ApiState {},
/// };
///
/// let app = Router::with_state(state)
/// let app = Router::new()
/// .route("/", get(handler))
/// .route("/api/users", get(api_users));
/// .route("/api/users", get(api_users))
/// .with_state(state);
///
/// async fn api_users(
/// // access the api specific state
@ -134,9 +138,11 @@ use std::{
/// State(state): State<AppState>,
/// ) {
/// }
/// # let _: Router<AppState> = app;
/// # let _: axum::routing::RouterService = app;
/// ```
///
/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way

View file

@ -1,123 +0,0 @@
use std::convert::Infallible;
use super::Handler;
use crate::routing::Route;
pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B, E>>);
impl<S, B> BoxedHandler<S, B>
where
S: Clone + Send + Sync + 'static,
B: Send + 'static,
{
pub(crate) fn new<H, T>(handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
}))
}
}
impl<S, B, E> BoxedHandler<S, B, E> {
pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedHandler<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
BoxedHandler(Box::new(Map {
handler: self.0,
layer: Box::new(f),
}))
}
pub(crate) fn into_route(self, state: S) -> Route<B, E> {
self.0.into_route(state)
}
}
impl<S, B, E> Clone for BoxedHandler<S, B, E> {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}
trait ErasedHandler<S, B, E = Infallible>: Send {
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>;
fn into_route(self: Box<Self>, state: S) -> Route<B, E>;
}
struct MakeErasedHandler<H, S, B> {
handler: H,
into_route: fn(H, S) -> Route<B>,
}
impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B>
where
H: Clone + Send + 'static,
S: 'static,
B: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B>> {
Box::new(self.clone())
}
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.handler, state)
}
}
impl<H: Clone, S, B> Clone for MakeErasedHandler<H, S, B> {
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
into_route: self.into_route,
}
}
}
struct Map<S, B, E, B2, E2> {
handler: Box<dyn ErasedHandler<S, B, E>>,
layer: Box<dyn LayerFn<B, E, B2, E2>>,
}
impl<S, B, E, B2, E2> ErasedHandler<S, B2, E2> for Map<S, B, E, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
B2: 'static,
E2: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B2, E2>> {
Box::new(Self {
handler: self.handler.clone_box(),
layer: self.layer.clone_box(),
})
}
fn into_route(self: Box<Self>, state: S) -> Route<B2, E2> {
(self.layer)(self.handler.into_route(state))
}
}
trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>;
}
impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F
where
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
{
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> {
Box::new(self.clone())
}
}

View file

@ -49,11 +49,9 @@ use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
mod boxed;
pub mod future;
mod service;
pub(crate) use self::boxed::BoxedHandler;
pub use self::service::HandlerService;
/// Trait for async functions that can be used to handle requests.

View file

@ -188,8 +188,9 @@
//!
//! let shared_state = Arc::new(AppState { /* ... */ });
//!
//! let app = Router::with_state(shared_state)
//! .route("/", get(handler));
//! let app = Router::new()
//! .route("/", get(handler))
//! .with_state(shared_state);
//!
//! async fn handler(
//! State(state): State<Arc<AppState>>,
@ -434,6 +435,7 @@
#[macro_use]
pub(crate) mod macros;
mod boxed;
mod extension;
#[cfg(feature = "form")]
mod form;

View file

@ -133,10 +133,11 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
///
/// let state = AppState { /* ... */ };
///
/// let app = Router::with_state(state.clone())
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer {

View file

@ -148,10 +148,11 @@ pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
///
/// let state = AppState { /* ... */ };
///
/// let app = Router::with_state(state.clone())
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(map_request_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// .route_layer(map_request_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
MapRequestLayer {

View file

@ -132,10 +132,11 @@ pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
///
/// let state = AppState { /* ... */ };
///
/// let app = Router::with_state(state.clone())
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(map_response_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// .route_layer(map_response_with_state(state.clone(), my_middleware))
/// .with_state(state);
/// # let _: axum::routing::RouterService = app;
/// ```
pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
MapResponseLayer {

View file

@ -98,7 +98,7 @@ mod tests {
}
}
Router::<_, Body>::new()
Router::<(), Body>::new()
.route("/", get(impl_trait_ok))
.route("/", get(impl_trait_err))
.route("/", get(impl_trait_both))
@ -208,7 +208,7 @@ mod tests {
)
}
Router::<_, Body>::new()
Router::<(), Body>::new()
.route("/", get(status))
.route("/", get(status_headermap))
.route("/", get(status_header_array))

View file

@ -5,12 +5,12 @@ use super::{FallbackRoute, IntoMakeService};
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, Bytes, HttpBody},
boxed::BoxedIntoRoute,
error_handling::{HandleError, HandleErrorLayer},
handler::{BoxedHandler, Handler},
handler::Handler,
http::{Method, Request, StatusCode},
response::Response,
routing::{future::RouteFuture, Fallback, MethodFilter, Route},
util::try_downcast,
};
use axum_core::response::IntoResponse;
use bytes::BytesMut;
@ -606,7 +606,7 @@ where
{
self.on_endpoint(
filter,
MethodEndpoint::BoxedHandler(BoxedHandler::new(handler)),
MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
)
}
@ -626,7 +626,7 @@ where
T: 'static,
S: Send + Sync + 'static,
{
self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
self
}
}
@ -749,46 +749,6 @@ where
}
}
pub(crate) fn map_state<S2>(self, state: &S) -> MethodRouter<S2, B, E>
where
E: 'static,
S: 'static,
S2: 'static,
{
MethodRouter {
get: self.get.map_state(state),
head: self.head.map_state(state),
delete: self.delete.map_state(state),
options: self.options.map_state(state),
patch: self.patch.map_state(state),
post: self.post.map_state(state),
put: self.put.map_state(state),
trace: self.trace.map_state(state),
fallback: self.fallback.map_state(state),
allow_header: self.allow_header,
}
}
pub(crate) fn downcast_state<S2>(self) -> Option<MethodRouter<S2, B, E>>
where
E: 'static,
S: 'static,
S2: 'static,
{
Some(MethodRouter {
get: self.get.downcast_state()?,
head: self.head.downcast_state()?,
delete: self.delete.downcast_state()?,
options: self.options.downcast_state()?,
patch: self.patch.downcast_state()?,
post: self.post.downcast_state()?,
put: self.put.downcast_state()?,
trace: self.trace.downcast_state()?,
fallback: self.fallback.downcast_state()?,
allow_header: self.allow_header,
})
}
/// Chain an additional service that will accept requests matching the given
/// `MethodFilter`.
///
@ -964,17 +924,14 @@ where
) -> MethodRouter<S, NewReqBody, NewError>
where
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
E: 'static,
S: 'static,
{
let layer_fn = move |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
let layer_fn = move |route: Route<B, E>| route.layer(layer.clone());
MethodRouter {
get: self.get.map(layer_fn.clone()),
@ -1182,7 +1139,7 @@ where
enum MethodEndpoint<S, B, E> {
None,
Route(Route<B, E>),
BoxedHandler(BoxedHandler<S, B, E>),
BoxedHandler(BoxedIntoRoute<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E>
@ -1213,32 +1170,6 @@ where
}
}
fn map_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E> {
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(route),
Self::BoxedHandler(handler) => MethodEndpoint::Route(handler.into_route(state.clone())),
}
}
fn downcast_state<S2>(self) -> Option<MethodEndpoint<S2, B, E>>
where
S: 'static,
B: 'static,
E: 'static,
S2: 'static,
{
match self {
Self::None => Some(MethodEndpoint::None),
Self::Route(route) => Some(MethodEndpoint::Route(route)),
Self::BoxedHandler(handler) => {
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.map(MethodEndpoint::BoxedHandler)
.ok()
}
}
}
fn into_route(self, state: &S) -> Option<Route<B, E>> {
match self {
Self::None => None,

View file

@ -1,27 +1,19 @@
//! Routing between [`Service`]s and handlers.
use self::not_found::NotFound;
use self::{not_found::NotFound, strip_prefix::StripPrefix};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, HttpBody},
handler::{BoxedHandler, Handler},
boxed::BoxedIntoRoute,
handler::Handler,
util::try_downcast,
};
use axum_core::response::{IntoResponse, Response};
use http::Request;
use matchit::MatchError;
use std::{
any::{type_name, TypeId},
collections::HashMap,
convert::Infallible,
fmt,
sync::Arc,
};
use tower::{
util::{BoxCloneService, MapResponseLayer, Oneshot},
ServiceBuilder,
};
use std::{collections::HashMap, convert::Infallible, fmt, sync::Arc};
use tower::util::{BoxCloneService, Oneshot};
use tower_layer::Layer;
use tower_service::Service;
@ -68,19 +60,14 @@ impl RouteId {
/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
state: Option<S>,
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<S, B>,
}
impl<S, B> Clone for Router<S, B>
where
S: Clone,
{
impl<S, B> Clone for Router<S, B> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
@ -91,10 +78,10 @@ where
impl<S, B> Default for Router<S, B>
where
B: HttpBody + Send + 'static,
S: Default + Clone + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::with_state(S::default())
Self::new()
}
}
@ -104,7 +91,6 @@ where
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("state", &self.state)
.field("routes", &self.routes)
.field("node", &self.node)
.field("fallback", &self.fallback)
@ -115,71 +101,19 @@ where
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
impl<B> Router<(), B>
where
B: HttpBody + Send + 'static,
{
/// Create a new `Router`.
///
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn new() -> Self {
Self::with_state(())
}
}
impl<B> Router<(), B> where B: HttpBody + Send + 'static {}
impl<S, B> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a new `Router` with the given state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
/// Create a new `Router`.
///
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn with_state(state: S) -> Self {
pub fn new() -> Self {
Self {
state: Some(state),
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
}
}
/// Create a new `Router` that inherits its state from another `Router` that it is merged into
/// or nested under.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// // A router that will be nested under the `app` router.
/// //
/// // By using `inherit_state` we'll reuse the state from the `app` router.
/// let nested_router = Router::inherit_state()
/// .route("/bar", get(|state: State<AppState>| async {}));
///
/// // A router that will be merged into the `app` router.
/// let merged_router = Router::inherit_state()
/// .route("/baz", get(|state: State<AppState>| async {}));
///
/// let app = Router::with_state(AppState {})
/// .route("/", get(|state: State<AppState>| async {}))
/// .nest("/foo", nested_router)
/// .merge(merged_router);
///
/// // `app` now has routes for `/`, `/foo/bar`, and `/baz` that all use the same state.
/// # let _: Router<AppState> = app;
/// ```
pub fn inherit_state() -> Self {
Self {
state: None,
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
@ -228,18 +162,12 @@ where
}
#[doc = include_str!("../docs/routing/route_service.md")]
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
let service = match try_downcast::<RouterService<B>, _>(service) {
Ok(_) => {
panic!(
@ -250,11 +178,20 @@ where
Err(svc) => svc,
};
self.route_endpoint(path, Endpoint::Route(Route::new(service)))
}
#[track_caller]
fn route_endpoint(mut self, path: &str, endpoint: Endpoint<S, B>) -> Self {
if path.is_empty() {
panic!("Paths must start with a `/`. Use \"/\" for root routes");
} else if !path.starts_with('/') {
panic!("Paths must start with a `/`");
}
let id = RouteId::next();
let endpoint = Endpoint::Route(Route::new(service));
self.set_node(path, id);
self.routes.insert(id, endpoint);
self
}
@ -270,29 +207,27 @@ where
#[doc = include_str!("../docs/routing/nest.md")]
#[track_caller]
pub fn nest<S2>(self, path: &str, mut router: Router<S2, B>) -> Self
where
S2: Clone + Send + Sync + 'static,
{
if router.state.is_none() {
let s = self.state.clone();
router.state = match try_downcast::<Option<S2>, Option<S>>(s) {
Ok(state) => state,
Err(_) => panic!(
"can't nest a `Router` that wants to inherit state of type `{}` \
into a `Router` with a state type of `{}`",
type_name::<S2>(),
type_name::<S>(),
),
};
}
self.nest_service(path, router.into_service())
pub fn nest(self, path: &str, router: Router<S, B>) -> Self {
self.nest_endpoint(path, RouterOrService::<_, _, NotFound>::Router(router))
}
/// Like [`nest`](Self::nest), but accepts an arbitrary `Service`.
#[track_caller]
pub fn nest_service<T>(mut self, mut path: &str, svc: T) -> Self
pub fn nest_service<T>(self, path: &str, svc: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
self.nest_endpoint(path, RouterOrService::Service(svc))
}
#[track_caller]
fn nest_endpoint<T>(
mut self,
mut path: &str,
router_or_service: RouterOrService<S, B, T>,
) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
@ -315,16 +250,27 @@ where
format!("{path}/*{NEST_TAIL_PARAM}")
};
let svc = strip_prefix::StripPrefix::new(svc, prefix);
self = self.route_service(&path, svc.clone());
let endpoint = match router_or_service {
RouterOrService::Router(router) => {
let prefix = prefix.to_owned();
let boxed = BoxedIntoRoute::from_router(router)
.map(move |route| Route::new(StripPrefix::new(route, &prefix)));
Endpoint::NestedRouter(boxed)
}
RouterOrService::Service(svc) => {
Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)))
}
};
self = self.route_endpoint(&path, endpoint.clone());
// `/*rest` is not matched by `/` so we need to also register a router at the
// prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
// wouldn't match, which it should
self = self.route_service(prefix, svc.clone());
self = self.route_endpoint(prefix, endpoint.clone());
if !prefix.ends_with('/') {
// same goes for `/foo/`, that should also match
self = self.route_service(&format!("{prefix}/"), svc);
self = self.route_endpoint(&format!("{prefix}/"), endpoint);
}
self
@ -332,66 +278,27 @@ where
#[doc = include_str!("../docs/routing/merge.md")]
#[track_caller]
pub fn merge<S2, R>(mut self, other: R) -> Self
pub fn merge<R>(mut self, other: R) -> Self
where
R: Into<Router<S2, B>>,
S2: Clone + Send + Sync + 'static,
R: Into<Router<S, B>>,
{
let Router {
state,
routes,
node,
fallback,
} = other.into();
let cast_method_router_closure_slot;
let (fallback, cast_method_router) = match state {
// other has its state set
Some(state) => {
let fallback = fallback.map_state(&state);
cast_method_router_closure_slot = move |r: MethodRouter<_, _>| r.map_state(&state);
let cast_method_router = &cast_method_router_closure_slot
as &dyn Fn(MethodRouter<_, _>) -> MethodRouter<_, _>;
(fallback, cast_method_router)
}
// other wants to inherit its state
None => {
if TypeId::of::<S>() != TypeId::of::<S2>() {
panic!(
"can't merge a `Router` that wants to inherit state of type `{}` \
into a `Router` with a state type of `{}`",
type_name::<S2>(),
type_name::<S>(),
);
}
// With the branch above not taken, we know we can cast S2 to S
let fallback = fallback.downcast_state::<S>().unwrap();
fn cast_method_router<S, S2, B>(r: MethodRouter<S2, B>) -> MethodRouter<S, B>
where
B: Send + 'static,
S: 'static,
S2: Clone + 'static,
{
r.downcast_state().unwrap()
}
(fallback, &cast_method_router as _)
}
};
for (id, route) in routes {
let path = node
.route_id_to_path
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");
self = match route {
Endpoint::MethodRouter(method_router) => {
self.route(path, cast_method_router(method_router))
}
Endpoint::MethodRouter(method_router) => self.route(path, method_router),
Endpoint::Route(route) => self.route_service(path, route),
Endpoint::NestedRouter(router) => {
self.route_endpoint(path, Endpoint::NestedRouter(router))
}
};
}
@ -412,30 +319,18 @@ where
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
{
let layer = ServiceBuilder::new()
.map_err(Into::into)
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let routes = self
.routes
.into_iter()
.map(|(id, route)| {
let route = match route {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(layer.clone()))
}
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
};
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
let fallback = self.fallback.map(move |svc| Route::new(layer.layer(svc)));
let fallback = self.fallback.map(|route| route.layer(layer));
Router {
state: self.state,
routes,
node: self.node,
fallback,
@ -459,28 +354,16 @@ where
);
}
let layer = ServiceBuilder::new()
.map_err(Into::into)
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
let routes = self
.routes
.into_iter()
.map(|(id, route)| {
let route = match route {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(layer.clone()))
}
Endpoint::Route(route) => Endpoint::Route(Route::new(layer.layer(route))),
};
.map(|(id, endpoint)| {
let route = endpoint.layer(layer.clone());
(id, route)
})
.collect();
Router {
state: self.state,
routes,
node: self.node,
fallback: self.fallback,
@ -493,7 +376,7 @@ where
H: Handler<T, S, B>,
T: 'static,
{
self.fallback = Fallback::BoxedHandler(BoxedHandler::new(handler));
self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
self
}
@ -510,14 +393,26 @@ where
self
}
/// Convert this router into a [`RouterService`] by providing the state.
///
/// Once this method has been called you cannot add more routes. So it must be called as last.
pub fn with_state(self, state: S) -> RouterService<B> {
RouterService::new(self, state)
}
}
impl<B> Router<(), B>
where
B: HttpBody + Send + 'static,
{
/// Convert this router into a [`RouterService`].
///
/// # Panics
/// This is a convenience method for routers that don't have any state (i.e. the state type is
/// `()`). Use [`Router::with_state`] otherwise.
///
/// Panics if the router was constructed with [`Router::inherit_state`].
#[track_caller]
/// Once this method has been called you cannot add more routes. So it must be called as last.
pub fn into_service(self) -> RouterService<B> {
RouterService::new(self)
RouterService::new(self, ())
}
/// Convert this router into a [`MakeService`], that is a [`Service`] whose
@ -542,8 +437,10 @@ where
/// # };
/// ```
///
/// This is a convenience method for routers that don't have any state (i.e. the state type is
/// `()`). Use [`RouterService::into_make_service`] otherwise.
///
/// [`MakeService`]: tower::make::MakeService
#[track_caller]
pub fn into_make_service(self) -> IntoMakeService<RouterService<B>> {
IntoMakeService::new(self.into_service())
}
@ -601,39 +498,13 @@ impl fmt::Debug for Node {
enum Fallback<S, B, E = Infallible> {
Default(Route<B, E>),
Service(Route<B, E>),
BoxedHandler(BoxedHandler<S, B, E>),
BoxedHandler(BoxedIntoRoute<S, B, E>),
}
impl<S, B, E> Fallback<S, B, E>
where
S: Clone,
{
fn map_state<S2>(self, state: &S) -> Fallback<S2, B, E> {
match self {
Self::Default(route) => Fallback::Default(route),
Self::Service(route) => Fallback::Service(route),
Self::BoxedHandler(handler) => Fallback::Service(handler.into_route(state.clone())),
}
}
fn downcast_state<S2>(self) -> Option<Fallback<S2, B, E>>
where
S: 'static,
B: 'static,
E: 'static,
S2: 'static,
{
match self {
Self::Default(route) => Some(Fallback::Default(route)),
Self::Service(route) => Some(Fallback::Service(route)),
Self::BoxedHandler(handler) => {
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.map(Fallback::BoxedHandler)
.ok()
}
}
}
fn merge(self, other: Self) -> Option<Self> {
match (self, other) {
(Self::Default(_), pick @ Self::Default(_)) => Some(pick),
@ -729,6 +600,41 @@ impl<B, E> FallbackRoute<B, E> {
enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>),
Route(Route<B>),
NestedRouter(BoxedIntoRoute<S, B, Infallible>),
}
impl<S, B> Endpoint<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
fn into_route(self, state: S) -> Route<B> {
match self {
Endpoint::MethodRouter(method_router) => Route::new(method_router.with_state(state)),
Endpoint::Route(route) => route,
Endpoint::NestedRouter(router) => router.into_route(state),
}
}
fn layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
{
match self {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(layer))
}
Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
Endpoint::NestedRouter(router) => {
Endpoint::NestedRouter(router.map(|route| route.layer(layer)))
}
}
}
}
impl<S, B> Clone for Endpoint<S, B> {
@ -736,19 +642,31 @@ impl<S, B> Clone for Endpoint<S, B> {
match self {
Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
Self::Route(inner) => Self::Route(inner.clone()),
Self::NestedRouter(router) => Self::NestedRouter(router.clone()),
}
}
}
impl<S, B> fmt::Debug for Endpoint<S, B> {
impl<S, B> fmt::Debug for Endpoint<S, B>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MethodRouter(inner) => inner.fmt(f),
Self::Route(inner) => inner.fmt(f),
Self::MethodRouter(method_router) => {
f.debug_tuple("MethodRouter").field(method_router).finish()
}
Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
Self::NestedRouter(router) => f.debug_tuple("NestedRouter").field(router).finish(),
}
}
}
enum RouterOrService<S, B, T> {
Router(Router<S, B>),
Service(T),
}
#[test]
#[allow(warnings)]
fn traits() {

View file

@ -17,9 +17,10 @@ use std::{
task::{Context, Poll},
};
use tower::{
util::{BoxCloneService, Oneshot},
ServiceExt,
util::{BoxCloneService, MapResponseLayer, Oneshot},
ServiceBuilder, ServiceExt,
};
use tower_layer::Layer;
use tower_service::Service;
/// How routes are stored inside a [`Router`](super::Router).
@ -46,6 +47,25 @@ impl<B, E> Route<B, E> {
) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
self.0.clone().oneshot(req)
}
pub(crate) fn layer<L, NewReqBody, NewError>(self, layer: L) -> Route<NewReqBody, NewError>
where
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: 'static,
NewError: 'static,
{
let layer = ServiceBuilder::new()
.map_err(Into::into)
.layer(MapResponseLayer::new(IntoResponse::into_response))
.layer(layer)
.into_inner();
Route::new(layer.layer(self))
}
}
impl<B, E> Clone for Route<B, E> {

View file

@ -1,5 +1,5 @@
use super::{
future::RouteFuture, url_params, Endpoint, FallbackRoute, Node, Route, RouteId, Router,
future::RouteFuture, url_params, FallbackRoute, IntoMakeService, Node, Route, RouteId, Router,
};
use crate::{
body::{Body, HttpBody},
@ -28,26 +28,17 @@ impl<B> RouterService<B>
where
B: HttpBody + Send + 'static,
{
#[track_caller]
pub(super) fn new<S>(router: Router<S, B>) -> Self
pub(super) fn new<S>(router: Router<S, B>, state: S) -> Self
where
S: Clone + Send + Sync + 'static,
{
let state = router
.state
.expect("Can't turn a `Router` that wants to inherit state into a service");
let fallback = router.fallback.into_fallback_route(&state);
let routes = router
.routes
.into_iter()
.map(|(route_id, endpoint)| {
let route = match endpoint {
Endpoint::MethodRouter(method_router) => {
Route::new(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => route,
};
let route = endpoint.into_route(state.clone());
(route_id, route)
})
.collect();
@ -55,7 +46,7 @@ where
Self {
routes,
node: router.node,
fallback: router.fallback.into_fallback_route(&state),
fallback,
}
}
@ -84,6 +75,28 @@ where
route.call(req)
}
/// Convert the router into a [`MakeService`] and no state.
///
/// See [`Router::into_make_service`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service(self) -> IntoMakeService<RouterService<B>> {
IntoMakeService::new(self)
}
/// Convert the router into a [`MakeService`] which stores information
/// about the incoming connection and has no state.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(
self,
) -> crate::extract::connect_info::IntoMakeServiceWithConnectInfo<RouterService<B>, C> {
crate::extract::connect_info::IntoMakeServiceWithConnectInfo::new(self)
}
}
impl<B> Clone for RouterService<B> {

View file

@ -1,3 +1,5 @@
use tower::ServiceExt;
use super::*;
use crate::middleware::{map_request, map_response};
@ -50,10 +52,11 @@ async fn or() {
#[tokio::test]
async fn fallback_accessing_state() {
let app = Router::with_state("state")
.fallback(|State(state): State<&'static str>| async move { state });
let app = Router::new()
.fallback(|State(state): State<&'static str>| async move { state })
.with_state("state");
let client = TestClient::new(app);
let client = TestClient::from_service(app);
let res = client.get("/does-not-exist").send().await;
assert_eq!(res.status(), StatusCode::OK);
@ -158,3 +161,45 @@ async fn also_inherits_default_layered_fallback() {
assert_eq!(res.headers()["x-from-fallback"], "1");
assert_eq!(res.text().await, "outer");
}
#[tokio::test]
async fn fallback_inherited_into_nested_router_service() {
let inner = Router::new()
.route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
)
.with_state("inner");
// with a different state
let app = Router::<()>::new()
.nest_service("/foo", inner)
.fallback(outer_fallback);
let client = TestClient::new(app);
let res = client.get("/foo/not-found").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}
#[tokio::test]
async fn fallback_inherited_into_nested_opaque_service() {
let inner = Router::new()
.route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
)
.with_state("inner")
// even if the service is made more opaque it should still inherit the fallback
.boxed_clone();
// with a different state
let app = Router::<()>::new()
.nest_service("/foo", inner)
.fallback(outer_fallback);
let client = TestClient::new(app);
let res = client.get("/foo/not-found").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "outer");
}

View file

@ -397,87 +397,3 @@ async fn middleware_that_return_early() {
);
assert_eq!(client.get("/public").send().await.status(), StatusCode::OK);
}
#[tokio::test]
async fn merge_with_different_state_type() {
let inner = Router::with_state("inner".to_owned()).route(
"/foo",
get(|State(state): State<String>| async move { state }),
);
let app = Router::with_state("outer").merge(inner).route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
);
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "inner");
let res = client.get("/bar").send().await;
assert_eq!(res.text().await, "outer");
}
#[tokio::test]
async fn merging_routes_different_method_different_states() {
let get = Router::with_state("get state").route(
"/",
get(|State(state): State<&'static str>| async move { state }),
);
let post = Router::with_state("post state").route(
"/",
post(|State(state): State<&'static str>| async move { state }),
);
let app = Router::new().merge(get).merge(post);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "get state");
let res = client.post("/").send().await;
assert_eq!(res.text().await, "post state");
}
#[tokio::test]
async fn merging_routes_different_paths_different_states() {
let foo = Router::with_state("foo state").route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
);
let bar = Router::with_state("bar state").route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
);
let app = Router::new().merge(foo).merge(bar);
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "foo state");
let res = client.get("/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "bar state");
}
#[tokio::test]
async fn inherit_state_via_merge() {
let foo = Router::inherit_state().route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
);
let app = Router::with_state("state").merge(foo);
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "state");
}

View file

@ -760,8 +760,8 @@ async fn extract_state() {
inner: InnerState { value: 2 },
};
let app = Router::with_state(state).route("/", get(handler));
let client = TestClient::new(app);
let app = Router::new().route("/", get(handler)).with_state(state);
let client = TestClient::from_service(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
@ -769,12 +769,14 @@ async fn extract_state() {
#[tokio::test]
async fn explicitly_set_state() {
let app = Router::with_state("...").route_service(
"/",
get(|State(state): State<&'static str>| async move { state }).with_state("foo"),
);
let app = Router::new()
.route_service(
"/",
get(|State(state): State<&'static str>| async move { state }).with_state("foo"),
)
.with_state("...");
let client = TestClient::new(app);
let client = TestClient::from_service(app);
let res = client.get("/").send().await;
assert_eq!(res.text().await, "foo");
}

View file

@ -265,7 +265,7 @@ async fn multiple_top_level_nests() {
#[tokio::test]
#[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")]
async fn nest_cannot_contain_wildcards() {
Router::<_, Body>::new().nest("/one/*rest", Router::new());
Router::<(), Body>::new().nest("/one/*rest", Router::new());
}
#[tokio::test]
@ -424,48 +424,3 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/");
nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/");
nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a");
nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/");
#[tokio::test]
async fn nesting_with_different_state() {
let inner = Router::with_state("inner".to_owned()).route(
"/foo",
get(|State(state): State<String>| async move { state }),
);
let outer = Router::with_state("outer")
.route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
)
.nest("/nested", inner)
.route(
"/bar",
get(|State(state): State<&'static str>| async move { state }),
);
let client = TestClient::new(outer);
let res = client.get("/foo").send().await;
assert_eq!(res.text().await, "outer");
let res = client.get("/nested/foo").send().await;
assert_eq!(res.text().await, "inner");
let res = client.get("/bar").send().await;
assert_eq!(res.text().await, "outer");
}
#[tokio::test]
async fn inherit_state_via_nest() {
let foo = Router::inherit_state().route(
"/foo",
get(|State(state): State<&'static str>| async move { state }),
);
let app = Router::with_state("state").nest("/test", foo);
let client = TestClient::new(app);
let res = client.get("/test/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "state");
}

View file

@ -15,10 +15,7 @@ pub(crate) struct TestClient {
}
impl TestClient {
pub(crate) fn new<S>(router: Router<S, Body>) -> Self
where
S: Clone + Send + Sync + 'static,
{
pub(crate) fn new(router: Router<(), Body>) -> Self {
Self::from_service(router.into_service())
}

View file

@ -34,7 +34,9 @@ async fn main() {
.data(StarWars::new())
.finish();
let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler));
let app = Router::new()
.route("/", get(graphql_playground).post(graphql_handler))
.with_state(schema);
println!("Playground: http://localhost:3000");

View file

@ -44,9 +44,10 @@ async fn main() {
let app_state = Arc::new(AppState { user_set, tx });
let app = Router::with_state(app_state)
let app = Router::new()
.route("/", get(index))
.route("/websocket", get(websocket_handler));
.route("/websocket", get(websocket_handler))
.with_state(app_state);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -36,9 +36,10 @@ async fn main() {
let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo;
// Build our application with some routes
let app = Router::with_state(user_repo)
let app = Router::new()
.route("/users/:id", get(users_show))
.route("/users", post(users_create));
.route("/users", post(users_create))
.with_state(user_repo);
// Run our application
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View file

@ -43,7 +43,7 @@ async fn main() {
let shared_state = SharedState::default();
// Build our application by composing routes
let app = Router::with_state(Arc::clone(&shared_state))
let app = Router::new()
.route(
"/:key",
// Add compression to `kv_get`
@ -60,7 +60,7 @@ async fn main() {
)
.route("/keys", get(list_keys))
// Nest our admin routes under `/admin`
.nest("/admin", admin_routes(shared_state))
.nest("/admin", admin_routes())
// Add middleware to all routes
.layer(
ServiceBuilder::new()
@ -69,9 +69,9 @@ async fn main() {
.load_shed()
.concurrency_limit(1024)
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.into_inner(),
);
.layer(TraceLayer::new_for_http()),
)
.with_state(Arc::clone(&shared_state));
// Run our app with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -115,7 +115,7 @@ async fn list_keys(State(state): State<SharedState>) -> String {
.join("\n")
}
fn admin_routes(state: SharedState) -> Router<SharedState> {
fn admin_routes() -> Router<SharedState> {
async fn delete_all_keys(State(state): State<SharedState>) {
state.write().unwrap().db.clear();
}
@ -124,7 +124,7 @@ fn admin_routes(state: SharedState) -> Router<SharedState> {
state.write().unwrap().db.remove(&key);
}
Router::with_state(state)
Router::new()
.route("/keys", delete(delete_all_keys))
.route("/key/:key", delete(remove_key))
// Require bearer auth for all admin routes

View file

@ -47,12 +47,13 @@ async fn main() {
oauth_client,
};
let app = Router::with_state(app_state)
let app = Router::new()
.route("/", get(index))
.route("/auth/discord", get(discord_auth))
.route("/auth/authorized", get(login_authorized))
.route("/protected", get(protected))
.route("/logout", get(logout));
.route("/logout", get(logout))
.with_state(app_state);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -24,7 +24,7 @@ async fn main() {
let client = Client::new();
let app = Router::with_state(client).route("/", get(handler));
let app = Router::new().route("/", get(handler)).with_state(client);
let addr = SocketAddr::from(([127, 0, 0, 1], 4000));
println!("reverse proxy listening on {}", addr);

View file

@ -39,7 +39,7 @@ async fn main() {
// `MemoryStore` just used as an example. Don't use this in production.
let store = MemoryStore::new();
let app = Router::with_state(store).route("/", get(handler));
let app = Router::new().route("/", get(handler)).with_state(store);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -46,10 +46,12 @@ async fn main() {
.expect("can connect to database");
// build our application with some routes
let app = Router::with_state(pool).route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
);
let app = Router::new()
.route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.with_state(pool);
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

View file

@ -46,7 +46,7 @@ async fn main() {
let db = Db::default();
// Compose the routes
let app = Router::with_state(db)
let app = Router::new()
.route("/todos", get(todos_index).post(todos_create))
.route("/todos/:id", patch(todos_update).delete(todos_delete))
// Add middleware to all routes
@ -65,7 +65,8 @@ async fn main() {
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.into_inner(),
);
)
.with_state(db);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);

View file

@ -33,10 +33,12 @@ async fn main() {
let pool = Pool::builder().build(manager).await.unwrap();
// build our application with some routes
let app = Router::with_state(pool).route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
);
let app = Router::new()
.route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.with_state(pool);
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));