what about typed methods?

This commit is contained in:
David Pedersen 2022-02-15 09:49:26 +01:00
parent d3bcf7778d
commit cc1f989467
5 changed files with 337 additions and 283 deletions

View file

@ -13,7 +13,10 @@ pub use self::resource::Resource;
pub use axum_macros::TypedPath;
#[cfg(feature = "typed-routing")]
pub use self::typed::{FirstElementIs, TypedPath};
pub use self::typed::{
Any, Delete, FirstTwoElementsAre, Get, Head, OneOf, Options, Patch, Post, Put, Trace,
TypedMethod, TypedPath,
};
/// Extension trait that adds additional methods to [`Router`].
pub trait RouterExt<B>: sealed::Sealed {
@ -42,108 +45,80 @@ pub trait RouterExt<B>: sealed::Sealed {
where
T: HasRoutes<B>;
/// Add a typed `GET` route to the router.
/// Add a typed route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
/// The method and path will be inferred from the first two arguments to the handler function
/// which must implement [`TypedMethod`] and [`TypedPath`] respectively.
///
/// See [`TypedPath`] for more details and examples.
/// # Example
///
/// ```rust
/// use serde::Deserialize;
/// use axum::{Router, extract::Json};
/// use axum_extra::routing::{
/// TypedPath,
/// Get,
/// Post,
/// Delete,
/// RouterExt, // for `Router::typed_*`
/// };
///
/// // A type safe route with `/users/:id` as its associated path.
/// #[derive(TypedPath, Deserialize)]
/// #[typed_path("/users/:id")]
/// struct UsersMember {
/// id: u32,
/// }
///
/// // A regular handler function that takes `Get` as its first argument and
/// // `UsersMember` as the second argument and thus creates a typed connection
/// // between this handler and `GET /users/:id`.
/// //
/// // The first argument must implement `TypedMethod` and the second must
/// // implement `TypedPath`.
/// async fn users_show(
/// _: Get,
/// UsersMember { id }: UsersMember,
/// ) {
/// // ...
/// }
///
/// let app = Router::new()
/// // Add our typed route to the router.
/// //
/// // The method and path will be inferred to `GET /users/:id` since `users_show`'s
/// // first argument is `Get` and the second is `UsersMember`.
/// .typed_route(users_show)
/// .typed_route(users_create)
/// .typed_route(users_destroy);
///
/// #[derive(TypedPath)]
/// #[typed_path("/users")]
/// struct UsersCollection;
///
/// #[derive(Deserialize)]
/// struct UsersCreatePayload { /* ... */ }
///
/// async fn users_create(
/// _: Post,
/// _: UsersCollection,
/// // Our handlers can accept other extractors.
/// Json(payload): Json<UsersCreatePayload>,
/// ) {
/// // ...
/// }
///
/// async fn users_destroy(_: Delete, _: UsersCollection) { /* ... */ }
///
/// #
/// # let app: Router<axum::body::Body> = app;
/// ```
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
fn typed_route<H, T, M, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `DELETE` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `HEAD` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `OPTIONS` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `PATCH` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `POST` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `PUT` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath;
/// Add a typed `TRACE` route to the router.
///
/// The path will be inferred from the first argument to the handler function which must
/// implement [`TypedPath`].
///
/// See [`TypedPath`] for more details and examples.
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
T: FirstTwoElementsAre<M, P> + 'static,
M: TypedMethod,
P: TypedPath;
}
@ -159,83 +134,14 @@ where
}
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self
fn typed_route<H, T, M, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
T: FirstTwoElementsAre<M, P> + 'static,
M: TypedMethod,
P: TypedPath,
{
self.route(P::PATH, axum::routing::get(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_delete<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::delete(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_head<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::head(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_options<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::options(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_patch<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::patch(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_post<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::post(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_put<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::put(handler))
}
#[cfg(feature = "typed-routing")]
fn typed_trace<H, T, P>(self, handler: H) -> Self
where
H: Handler<T, B>,
T: FirstElementIs<P> + 'static,
P: TypedPath,
{
self.route(P::PATH, axum::routing::trace(handler))
self.route(P::PATH, M::apply_method_router(handler))
}
}

View file

@ -1,66 +1,16 @@
use super::sealed::Sealed;
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
handler::Handler,
routing::MethodRouter,
};
use std::{convert::Infallible, fmt, marker::PhantomData};
/// A type safe path.
/// A typed request path.
///
/// This is used to statically connect a path to its corresponding handler using
/// [`RouterExt::typed_get`], [`RouterExt::typed_post`], etc.
///
/// # Example
///
/// ```rust
/// use serde::Deserialize;
/// use axum::{Router, extract::Json};
/// use axum_extra::routing::{
/// TypedPath,
/// RouterExt, // for `Router::typed_*`
/// };
///
/// // A type safe route with `/users/:id` as its associated path.
/// #[derive(TypedPath, Deserialize)]
/// #[typed_path("/users/:id")]
/// struct UsersMember {
/// id: u32,
/// }
///
/// // A regular handler function that takes `UsersMember` as the first argument
/// // and thus creates a typed connection between this handler and the `/users/:id` path.
/// //
/// // The `TypedPath` must be the first argument to the function.
/// async fn users_show(
/// UsersMember { id }: UsersMember,
/// ) {
/// // ...
/// }
///
/// let app = Router::new()
/// // Add our typed route to the router.
/// //
/// // The path will be inferred to `/users/:id` since `users_show`'s
/// // first argument is `UsersMember` which implements `TypedPath`
/// .typed_get(users_show)
/// .typed_post(users_create)
/// .typed_delete(users_destroy);
///
/// #[derive(TypedPath)]
/// #[typed_path("/users")]
/// struct UsersCollection;
///
/// #[derive(Deserialize)]
/// struct UsersCreatePayload { /* ... */ }
///
/// async fn users_create(
/// _: UsersCollection,
/// // Our handlers can accept other extractors.
/// Json(payload): Json<UsersCreatePayload>,
/// ) {
/// // ...
/// }
///
/// async fn users_destroy(_: UsersCollection) { /* ... */ }
///
/// #
/// # let app: Router<axum::body::Body> = app;
/// ```
/// [`RouterExt::typed_route`]. See that method for more details.
///
/// # Using `#[derive(TypedPath)]`
///
@ -80,9 +30,9 @@ use super::sealed::Sealed;
/// The macro expands to:
///
/// - A `TypedPath` implementation.
/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`],
/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must
/// also implement [`serde::Deserialize`], unless it's a unit struct.
/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_route`]. This
/// implementation uses [`Path`] and thus your struct must also implement [`serde::Deserialize`],
/// unless it's a unit struct.
/// - A [`Display`] implementation that interpolates the captures. This can be used to, among other
/// things, create links to known paths and have them verified statically. Note that the
/// [`Display`] implementation for each field must return something that's compatible with its
@ -118,8 +68,7 @@ use super::sealed::Sealed;
/// ```
///
/// [`FromRequest`]: axum::extract::FromRequest
/// [`RouterExt::typed_get`]: super::RouterExt::typed_get
/// [`RouterExt::typed_post`]: super::RouterExt::typed_post
/// [`RouterExt::typed_route`]: super::RouterExt::typed_route
/// [`Path`]: axum::extract::Path
/// [`Display`]: std::fmt::Display
/// [`Deserialize`]: serde::Deserialize
@ -128,6 +77,232 @@ pub trait TypedPath: std::fmt::Display {
const PATH: &'static str;
}
/// A typed HTTP method.
///
/// This is used to statically connect an HTTP method to its corresponding handler using
/// [`RouterExt::typed_route`]. See that method for more details.
///
/// This trait is sealed such that it cannot be implemented outside this crate.
///
/// [`RouterExt::typed_route`]: super::RouterExt::typed_route
pub trait TypedMethod: Sealed {
/// Wrap a handler in a [`MethodRouter`] that accepts this type's corresponding HTTP method.
fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B>
where
H: Handler<T, B>,
B: Send + 'static,
T: 'static;
/// Check if the request matches this type's corresponding HTTP method.
fn matches_method(method: &http::Method) -> bool;
}
macro_rules! typed_method {
($name:ident, $method_router_constructor:ident, $method:ident) => {
#[doc = concat!("A `TypedMethod` that accepts `", stringify!($method), "` requests.")]
#[derive(Clone, Copy, Debug)]
pub struct $name;
impl Sealed for $name {}
impl TypedMethod for $name {
fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B>
where
H: Handler<T, B>,
B: Send + 'static,
T: 'static,
{
axum::routing::$method_router_constructor(handler)
}
fn matches_method(method: &http::Method) -> bool {
method == http::Method::$method
}
}
#[async_trait]
impl<B> FromRequest<B> for $name
where
B: Send,
{
type Rejection = http::StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if Self::matches_method(req.method()) {
Ok(Self)
} else {
Err(http::StatusCode::NOT_FOUND)
}
}
}
};
}
typed_method!(Delete, delete, DELETE);
typed_method!(Get, get, GET);
typed_method!(Head, head, HEAD);
typed_method!(Options, options, OPTIONS);
typed_method!(Patch, patch, PATCH);
typed_method!(Post, post, POST);
typed_method!(Put, put, PUT);
typed_method!(Trace, trace, TRACE);
/// A [`TypedMethod`] that accepts all HTTP methods.
///
/// # Example
///
/// ```rust
/// use axum_extra::routing::{TypedPath, Any, RouterExt};
/// use axum::Router;
///
/// #[derive(TypedPath)]
/// #[typed_path("/foo")]
/// struct Foo;
///
/// // This accepts `/foo` with any HTTP method.
/// async fn foo(_: Any, _: Foo) {}
///
/// let app = Router::new().typed_route(foo);
/// #
/// # let app: Router<axum::body::Body> = app;
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Any;
impl Sealed for Any {}
impl TypedMethod for Any {
fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B>
where
H: Handler<T, B>,
B: Send + 'static,
T: 'static,
{
axum::routing::any(handler)
}
fn matches_method(_method: &http::Method) -> bool {
true
}
}
#[async_trait]
impl<B> FromRequest<B> for Any
where
B: Send,
{
type Rejection = Infallible;
async fn from_request(_: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
Ok(Self)
}
}
/// A [`TypedMethod`] that accepts one of a number of HTTP methods.
///
/// # Example
///
/// ```rust
/// use axum_extra::routing::{TypedPath, OneOf, Patch, Put, RouterExt};
/// use axum::Router;
///
/// #[derive(TypedPath)]
/// #[typed_path("/foo")]
/// struct Foo;
///
/// // This accepts `PATCH /foo` and `PUT /foo`
/// async fn foo(_: OneOf<(Patch, Put)>, _: Foo) {}
///
/// let app = Router::new().typed_route(foo);
/// #
/// # let app: Router<axum::body::Body> = app;
/// ```
pub struct OneOf<T>(PhantomData<T>);
macro_rules! one_of {
($($ty:ident),* $(,)?) => {
impl<$($ty,)*> TypedMethod for OneOf<($($ty,)*)>
where
$( $ty: TypedMethod, )*
{
#[allow(clippy::redundant_clone, unused_mut, unused_variables)]
fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B>
where
H: Handler<T, B>,
B: Send + 'static,
T: 'static,
{
let mut method_router = MethodRouter::new();
$(
method_router = method_router.merge($ty::apply_method_router(handler.clone()));
)*
method_router
}
#[allow(unused_variables)]
fn matches_method(method: &http::Method) -> bool {
$(
if $ty::matches_method(method) {
return true;
}
)*
false
}
}
impl<$($ty,)*> Sealed for OneOf<($($ty,)*)> {}
#[async_trait]
impl<B, $($ty,)*> FromRequest<B> for OneOf<($($ty,)*)>
where
B: Send,
$( $ty: TypedMethod + FromRequest<B>, )*
{
type Rejection = http::StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if Self::matches_method(req.method()) {
Ok(Self(PhantomData))
} else {
Err(http::StatusCode::NOT_FOUND)
}
}
}
};
}
one_of!();
one_of!(T1,);
one_of!(T1, T2);
one_of!(T1, T2, T3);
one_of!(T1, T2, T3, T4);
one_of!(T1, T2, T3, T4, T5);
one_of!(T1, T2, T3, T4, T5, T6);
one_of!(T1, T2, T3, T4, T5, T6, T7);
one_of!(T1, T2, T3, T4, T5, T6, T7, T8);
impl<T> fmt::Debug for OneOf<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("OneOf")
.field(&format_args!("{}", std::any::type_name::<T>()))
.finish()
}
}
impl<T> Default for OneOf<T> {
fn default() -> Self {
Self(Default::default())
}
}
impl<T> Clone for OneOf<T> {
fn clone(&self) -> Self {
Self(self.0)
}
}
impl<T> Copy for OneOf<T> {}
/// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a
/// given type.
///
@ -139,18 +314,20 @@ pub trait TypedPath: std::fmt::Display {
/// It is sealed such that it cannot be implemented outside this crate.
///
/// [`RouterExt`]: super::RouterExt
pub trait FirstElementIs<P>: Sealed {}
pub trait FirstTwoElementsAre<M, P>: Sealed {}
macro_rules! impl_first_element_is {
( $($ty:ident),* $(,)? ) => {
impl<P, $($ty,)*> FirstElementIs<P> for (P, $($ty,)*)
impl<M, P, $($ty,)*> FirstTwoElementsAre<M, P> for (M, P, $($ty,)*)
where
P: TypedPath
M: TypedMethod,
P: TypedPath,
{}
impl<P, $($ty,)*> Sealed for (P, $($ty,)*)
impl<M, P, $($ty,)*> Sealed for (M, P, $($ty,)*)
where
P: TypedPath
M: TypedMethod,
P: TypedPath,
{}
};
}

View file

@ -7,6 +7,3 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
axum-extra = { path = "../../axum-extra", features = ["typed-routing"] }
serde = { version = "1.0", features = ["derive"] }

View file

@ -4,52 +4,23 @@
//! cargo run -p example-hello-world
//! ```
// Just using this file for manual testing. Will be cleaned up before an eventual merge
use axum::{response::IntoResponse, Router};
use axum_extra::routing::{RouterExt, TypedPath};
use serde::Deserialize;
use axum::{response::Html, routing::get, Router};
use std::net::SocketAddr;
#[tokio::main]
async fn main() {
let app = Router::new()
.typed_get(users_index)
.typed_post(users_create)
.typed_get(users_show)
.typed_get(users_edit);
// build our application with a route
let app = Router::new().route("/", get(handler));
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
#[derive(TypedPath)]
#[typed_path("/users")]
struct UsersCollection;
#[derive(Deserialize, TypedPath)]
#[typed_path("/users/:id")]
struct UsersMember {
id: u32,
}
#[derive(Deserialize, TypedPath)]
#[typed_path("/users/:id/edit")]
struct UsersEdit(u32);
async fn users_index(_: UsersCollection) -> impl IntoResponse {
"users#index"
}
async fn users_create(_: UsersCollection, _payload: String) -> impl IntoResponse {
"users#create"
}
async fn users_show(UsersMember { id }: UsersMember) -> impl IntoResponse {
format!("users#show: {}", id)
}
async fn users_edit(UsersEdit(id): UsersEdit) -> impl IntoResponse {
format!("users#edit: {}", id)
async fn handler() -> Html<&'static str> {
Html("<h1>Hello, World!</h1>")
}

View file

@ -37,7 +37,10 @@ async fn main() {
let app = Router::new()
.fallback(static_files_service)
.route("/sse", get(sse_handler))
.layer(TraceLayer::new_for_http());
.layer(
TraceLayer::new_for_http()
.make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true)),
);
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));