checkpoint

This commit is contained in:
David Pedersen 2022-07-03 21:07:29 +02:00
parent 7e30205205
commit 02f9ebaee9
13 changed files with 172 additions and 161 deletions

View file

@ -2,8 +2,8 @@
members = [
"axum",
"axum-core",
# "axum-extra",
# "axum-macros",
"axum-extra",
"axum-macros",
# internal crate used to bump the minimum versions we
# get for some dependencies which otherwise wouldn't build

View file

@ -88,14 +88,15 @@ pub struct Cached<T>(pub T);
struct CachedEntry<T>(T);
#[async_trait]
impl<B, T> FromRequest<B> for Cached<T>
impl<S, B, T> FromRequest<S, B> for Cached<T>
where
S: Send,
B: Send,
T: FromRequest<B> + Clone + Send + Sync + 'static,
T: FromRequest<S, B> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(_) => {
@ -139,19 +140,20 @@ mod tests {
struct Extractor(Instant);
#[async_trait]
impl<B> FromRequest<B> for Extractor
impl<S, B> FromRequest<S, B> for Extractor
where
S: Send,
B: Send,
{
type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now()))
}
}
let mut req = RequestParts::new(Request::new(()));
let mut req = RequestParts::new((), Request::new(()));
let first = Cached::<Extractor>::from_request(&mut req).await.unwrap().0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

View file

@ -88,13 +88,14 @@ pub struct CookieJar {
}
#[async_trait]
impl<B> FromRequest<B> for CookieJar
impl<S, B> FromRequest<S, B> for CookieJar
where
S: Send,
B: Send,
{
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let mut jar = cookie_lib::CookieJar::new();
for cookie in cookies_from_request(req) {
jar.add_original(cookie);
@ -103,8 +104,8 @@ where
}
}
fn cookies_from_request<B>(
req: &mut RequestParts<B>,
fn cookies_from_request<S, B>(
req: &mut RequestParts<S, B>,
) -> impl Iterator<Item = Cookie<'static>> + '_ {
req.headers()
.get_all(COOKIE)
@ -226,13 +227,12 @@ mod tests {
jar.remove(Cookie::named("key"))
}
let app = Router::<_, Body, _>::new()
let app = Router::without_state()
.route("/set", get(set_cookie))
.route("/get", get(get_cookie))
.route("/remove", get(remove_cookie))
.layer(Extension(Key::generate()))
.layer(Extension(CustomKey(Key::generate())))
.state(());
.layer(Extension(CustomKey(Key::generate())));
let res = app
.clone()
@ -295,10 +295,9 @@ mod tests {
format!("{:?}", jar.get("key"))
}
let app = Router::<_, Body, _>::new()
let app = Router::without_state()
.route("/get", get(get_cookie))
.layer(Extension(Key::generate()))
.state(());
.layer(Extension(Key::generate()));
let res = app
.clone()

View file

@ -74,14 +74,15 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
}
#[async_trait]
impl<B, K> FromRequest<B> for PrivateCookieJar<K>
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where
B: Send,
S: Send,
K: Into<Key> + Clone + Send + Sync + 'static,
{
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection;
type Rejection = <axum::Extension<K> as FromRequest<S, B>>::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let key = Extension::<K>::from_request(req).await?.0.into();
let mut jar = cookie_lib::CookieJar::new();

View file

@ -92,14 +92,15 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
}
#[async_trait]
impl<B, K> FromRequest<B> for SignedCookieJar<K>
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where
B: Send,
S: Send,
K: Into<Key> + Clone + Send + Sync + 'static,
{
type Rejection = <axum::Extension<K> as FromRequest<B>>::Rejection;
type Rejection = <axum::Extension<K> as FromRequest<S, B>>::Rejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let key = Extension::<K>::from_request(req).await?.0.into();
let mut jar = cookie_lib::CookieJar::new();

View file

@ -54,16 +54,17 @@ impl<T> Deref for Form<T> {
}
#[async_trait]
impl<T, B> FromRequest<B> for Form<T>
impl<T, S, B> FromRequest<S, B> for Form<T>
where
T: DeserializeOwned,
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = FormRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
if req.method() == Method::GET {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
@ -84,7 +85,7 @@ where
}
// this is duplicated in `axum/src/extract/mod.rs`
fn has_content_type<B>(req: &RequestParts<B>, expected_content_type: &mime::Mime) -> bool {
fn has_content_type<S, B>(req: &RequestParts<S, B>, expected_content_type: &mime::Mime) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
@ -116,12 +117,10 @@ mod tests {
values: Vec<String>,
}
let app = Router::new()
.route(
"/",
post(|Form(data): Form<Data>| async move { data.values.join(",") }),
)
.state(());
let app = Router::without_state().route(
"/",
post(|Form(data): Form<Data>| async move { data.values.join(",") }),
);
let client = TestClient::new(app);

View file

@ -58,14 +58,15 @@ use std::ops::Deref;
pub struct Query<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for Query<T>
impl<T, S, B> FromRequest<S, B> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send,
{
type Rejection = QueryRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap_or_default();
let value = serde_html_form::from_str(query)
.map_err(FailedToDeserializeQueryString::__private_new::<T, _>)?;
@ -97,12 +98,10 @@ mod tests {
values: Vec<String>,
}
let app = Router::new()
.route(
"/",
post(|Query(data): Query<Data>| async move { data.values.join(",") }),
)
.state(());
let app = Router::without_state().route(
"/",
post(|Query(data): Query<Data>| async move { data.values.join(",") }),
);
let client = TestClient::new(app);

View file

@ -98,16 +98,17 @@ impl<S> JsonLines<S, AsResponse> {
}
#[async_trait]
impl<B, T> FromRequest<B> for JsonLines<T, AsExtractor>
impl<S, B, T> FromRequest<S, B> for JsonLines<T, AsExtractor>
where
B: HttpBody + Send + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
T: DeserializeOwned,
S: Send,
{
type Rejection = BodyAlreadyExtracted;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
// `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead`
// so we can call `AsyncRead::lines` and then convert it back to a `Stream`
@ -217,24 +218,22 @@ mod tests {
#[tokio::test]
async fn extractor() {
let app = Router::new()
.route(
"/",
post(|mut stream: JsonLines<User>| async move {
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 1 });
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 2 });
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 3 });
let app = Router::without_state().route(
"/",
post(|mut stream: JsonLines<User>| async move {
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 1 });
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 2 });
assert_eq!(stream.next().await.unwrap().unwrap(), User { id: 3 });
// sources are downcastable to `serde_json::Error`
let err = stream.next().await.unwrap().unwrap_err();
let _: &serde_json::Error = err
.source()
.unwrap()
.downcast_ref::<serde_json::Error>()
.unwrap();
}),
)
.state(());
// sources are downcastable to `serde_json::Error`
let err = stream.next().await.unwrap().unwrap_err();
let _: &serde_json::Error = err
.source()
.unwrap()
.downcast_ref::<serde_json::Error>()
.unwrap();
}),
);
let client = TestClient::new(app);
@ -257,19 +256,17 @@ mod tests {
#[tokio::test]
async fn response() {
let app = Router::new()
.route(
"/",
get(|| async {
let values = futures_util::stream::iter(vec![
Ok::<_, Infallible>(User { id: 1 }),
Ok::<_, Infallible>(User { id: 2 }),
Ok::<_, Infallible>(User { id: 3 }),
]);
JsonLines::new(values)
}),
)
.state(());
let app = Router::without_state().route(
"/",
get(|| async {
let values = futures_util::stream::iter(vec![
Ok::<_, Infallible>(User { id: 1 }),
Ok::<_, Infallible>(User { id: 2 }),
Ok::<_, Infallible>(User { id: 3 }),
]);
JsonLines::new(values)
}),
);
let client = TestClient::new(app);

View file

@ -4,6 +4,7 @@ use axum::{
handler::Handler,
http::Request,
response::{Redirect, Response},
routing::{MethodRouter, MissingState},
Router,
};
use std::{convert::Infallible, future::ready};
@ -29,7 +30,7 @@ pub use self::typed::{FirstElementIs, TypedPath};
pub use self::spa::SpaRouter;
/// Extension trait that adds additional methods to [`Router`].
pub trait RouterExt<S, B, R>: sealed::Sealed {
pub trait RouterExt<S, R, B>: sealed::Sealed {
/// Add a typed `GET` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
@ -39,7 +40,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -52,7 +53,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -65,7 +66,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -78,7 +79,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -91,7 +92,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -104,7 +105,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -117,7 +118,7 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
@ -130,10 +131,18 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
fn route_with_tsr(
self,
path: &str,
// TODO(david): constrain this so it only accepts methods
// routers containing handlers
method_router: MethodRouter<S, MissingState, B, Infallible>,
) -> Self;
/// Add another route to the router with an additional "trailing slash redirect" route.
///
/// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
@ -159,23 +168,23 @@ pub trait RouterExt<S, B, R>: sealed::Sealed {
/// .route_with_tsr("/bar/", get(|| async {}));
/// # let _: Router = app;
/// ```
fn route_with_tsr<T>(self, path: &str, service: T) -> Self
fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static,
T::Future: Send + 'static,
Self: Sized;
}
impl<S, B, R> RouterExt<S, B, R> for Router<S, B, R>
impl<S, B, R> RouterExt<S, R, B> for Router<S, R, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
R: 'static,
S: 'static,
{
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -185,7 +194,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -195,7 +204,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -205,7 +214,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -215,7 +224,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -225,7 +234,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -235,7 +244,7 @@ where
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
@ -245,32 +254,42 @@ where
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::trace(handler))
}
fn route_with_tsr<T>(mut self, path: &str, service: T) -> Self
fn route_with_tsr(
mut self,
path: &str,
// TODO(david): constrain this so it only accepts methods
// routers containing handlers
method_router: MethodRouter<S, MissingState, B, Infallible>,
) -> Self {
todo!()
}
fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static,
T::Future: Send + 'static,
Self: Sized,
{
self = self.route(path, service);
self = self.route_service(path, service);
let redirect = Redirect::permanent(path);
if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
self.route(
self.route_service(
path_without_trailing_slash,
(move || ready(redirect.clone())).into_service(),
(move || ready(redirect.clone())).into_service(()),
)
} else {
self.route(
self.route_service(
&format!("{}/", path),
(move || ready(redirect.clone())).into_service(),
(move || ready(redirect.clone())).into_service(()),
)
}
}
@ -289,10 +308,9 @@ mod tests {
#[tokio::test]
async fn test_tsr() {
let app = Router::new()
let app = Router::without_state()
.route_with_tsr("/foo", get(|| async {}))
.route_with_tsr("/bar/", get(|| async {}))
.state(());
.route_with_tsr("/bar/", get(|| async {}));
let client = TestClient::new(app);

View file

@ -3,7 +3,7 @@ use axum::{
handler::Handler,
http::Request,
response::Response,
routing::{delete, get, on, post, MethodFilter, MissingState, WithState},
routing::{delete, get, on, post, MethodFilter, MethodRouter, MissingState, WithState},
Router,
};
use std::{convert::Infallible, fmt};
@ -47,12 +47,12 @@ use tower_service::Service;
/// let app = Router::new().merge(users);
/// # let _: Router<axum::body::Body> = app;
/// ```
pub struct Resource<S, B = Body, R = MissingState> {
pub struct Resource<S, R, B> {
pub(crate) name: String,
pub(crate) router: Router<S, B, R>,
pub(crate) router: Router<S, R, B>,
}
impl<S, B, R> fmt::Debug for Resource<S, B, R>
impl<S, B, R> fmt::Debug for Resource<S, R, B>
where
S: fmt::Debug,
{
@ -64,7 +64,7 @@ where
}
}
impl<S, B> Resource<S, B, MissingState>
impl<S, B> Resource<S, MissingState, B>
where
B: axum::body::HttpBody + Send + 'static,
{
@ -77,29 +77,18 @@ where
router: Default::default(),
}
}
/// TODO(david): docs
pub fn state(self, state: S) -> Resource<S, B, WithState>
where
S: Clone,
{
Resource {
name: self.name,
router: self.router.state(state),
}
}
}
impl<S, B, R> Resource<S, B, R>
impl<S, R, B> Resource<S, R, B>
where
B: axum::body::HttpBody + Send + 'static,
S: 'static,
S: Clone + Send + Sync + 'static,
R: 'static,
{
/// Add a handler at `GET /{resource_name}`.
pub fn index<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = self.index_create_path();
@ -109,7 +98,7 @@ where
/// Add a handler at `POST /{resource_name}`.
pub fn create<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = self.index_create_path();
@ -119,7 +108,7 @@ where
/// Add a handler at `GET /{resource_name}/new`.
pub fn new<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = format!("/{}/new", self.name);
@ -129,7 +118,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id`.
pub fn show<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -139,7 +128,7 @@ where
/// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`.
pub fn edit<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = format!("/{0}/:{0}_id/edit", self.name);
@ -149,7 +138,7 @@ where
/// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`.
pub fn update<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -159,7 +148,7 @@ where
/// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`.
pub fn destroy<H, T>(self, handler: H) -> Self
where
H: Handler<T, B>,
H: Handler<S, T, B>,
T: 'static,
{
let path = self.show_update_destroy_path();
@ -169,7 +158,7 @@ where
/// Nest another router at the "member level".
///
/// The routes will be nested at `/{resource_name}/:{resource_name}_id`.
pub fn nest(mut self, router: Router<S, B, MissingState>) -> Self {
pub fn nest(mut self, router: Router<S, MissingState, B>) -> Self {
let path = self.show_update_destroy_path();
self.router = self.router.nest(&path, router);
self
@ -178,7 +167,7 @@ where
/// Nest another router at the "collection level".
///
/// The routes will be nested at `/{resource_name}`.
pub fn nest_collection(mut self, router: Router<S, B, MissingState>) -> Self {
pub fn nest_collection(mut self, router: Router<S, MissingState, B>) -> Self {
let path = self.index_create_path();
self.router = self.router.nest(&path, router);
self
@ -192,18 +181,18 @@ where
format!("/{0}/:{0}_id", self.name)
}
fn route<T>(mut self, path: &str, svc: T) -> Self
where
T: Service<Request<B>, Response = Response, Error = Infallible> + Clone + Send + 'static,
T::Future: Send + 'static,
{
self.router = self.router.route(path, svc);
fn route(
mut self,
path: &str,
method_router: MethodRouter<S, MissingState, B, Infallible>,
) -> Self {
self.router = self.router.route(path, method_router);
self
}
}
impl<S, B> From<Resource<S, B, MissingState>> for Router<S, B, MissingState> {
fn from(resource: Resource<S, B, MissingState>) -> Self {
impl<S, B> From<Resource<S, MissingState, B>> for Router<S, MissingState, B> {
fn from(resource: Resource<S, MissingState, B>) -> Self {
resource.router
}
}
@ -233,7 +222,7 @@ mod tests {
Router::new().route("/featured", get(|| async move { "users#featured" })),
);
let mut app = Router::new().merge(users).state(());
let mut app = Router::without_state().merge(users);
assert_eq!(
call_route(&mut app, Method::GET, "/users").await,
@ -286,11 +275,7 @@ mod tests {
);
}
async fn call_route(
app: &mut Router<(), Body, axum::routing::WithState>,
method: Method,
uri: &str,
) -> String {
async fn call_route(app: &mut Router<(), WithState>, method: Method, uri: &str) -> String {
let res = app
.ready()
.await

View file

@ -147,7 +147,7 @@ impl<B, T, F> SpaRouter<B, T, F> {
}
}
impl<B, F, T, S> From<SpaRouter<B, T, F>> for Router<S, B, MissingState>
impl<B, F, T, S> From<SpaRouter<B, T, F>> for Router<S, MissingState, B>
where
F: Clone + Send + 'static,
HandleError<Route<B, io::Error>, F, T>:
@ -155,17 +155,22 @@ where
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
B: HttpBody + Send + 'static,
T: 'static,
S: 'static,
S: Clone + Send + Sync + 'static,
{
fn from(spa: SpaRouter<B, T, F>) -> Self {
let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir))
.handle_error(spa.handle_error.clone());
.handle_error(spa.handle_error.clone())
// TODO(david): having to do this is annoying
.state(());
let fallback_service = get_service(ServeFile::new(&spa.paths.index_file))
.handle_error(spa.handle_error)
// TODO(david): having to do this is annoying
.state(());
Router::new()
.nest_service(&spa.paths.assets_path, assets_service)
.fallback(
get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error),
)
.fallback_service(fallback_service)
}
}
@ -213,10 +218,9 @@ mod tests {
#[tokio::test]
async fn basic() {
let app = Router::new()
let app = Router::without_state()
.route("/foo", get(|| async { "GET /foo" }))
.merge(SpaRouter::new("/assets", "test_files"))
.state(());
.merge(SpaRouter::new("/assets", "test_files"));
let client = TestClient::new(app);
let res = client.get("/").send().await;
@ -241,9 +245,8 @@ mod tests {
#[tokio::test]
async fn setting_index_file() {
let app = Router::new()
.merge(SpaRouter::new("/assets", "test_files").index_file("index_2.html"))
.state(());
let app = Router::without_state()
.merge(SpaRouter::new("/assets", "test_files").index_file("index_2.html"));
let client = TestClient::new(app);
let res = client.get("/").send().await;
@ -267,6 +270,6 @@ mod tests {
let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error);
Router::<(), Body, _>::new().merge(spa);
Router::<()>::new().merge(spa);
}
}

View file

@ -1469,4 +1469,11 @@ mod tests {
async fn created() -> (StatusCode, &'static str) {
(StatusCode::CREATED, "created")
}
#[test]
fn service_constructors_have_state() {
// shouldn't have to do
// get_service(...).state(...)
todo!()
}
}

View file

@ -12,7 +12,7 @@ use std::net::SocketAddr;
#[tokio::main]
async fn main() {
let app = Router::new()
let app = Router::without_state()
.merge(root())
.merge(get_foo())
.merge(post_foo());
@ -25,7 +25,7 @@ async fn main() {
.unwrap();
}
fn root() -> Router {
fn root() -> Router<()> {
async fn handler() -> &'static str {
"Hello, World!"
}
@ -33,7 +33,7 @@ fn root() -> Router {
route("/", get(handler))
}
fn get_foo() -> Router {
fn get_foo() -> Router<()> {
async fn handler() -> &'static str {
"Hi from `GET /foo`"
}
@ -41,7 +41,7 @@ fn get_foo() -> Router {
route("/foo", get(handler))
}
fn post_foo() -> Router {
fn post_foo() -> Router<()> {
async fn handler() -> &'static str {
"Hi from `POST /foo`"
}
@ -49,6 +49,6 @@ fn post_foo() -> Router {
route("/foo", post(handler))
}
fn route(path: &str, method_router: MethodRouter) -> Router {
fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> {
Router::new().route(path, method_router)
}