From cc1f98946717db6e595ad36f78166c0810246374 Mon Sep 17 00:00:00 2001 From: David Pedersen <david.pdrsn@gmail.com> Date: Tue, 15 Feb 2022 09:49:26 +0100 Subject: [PATCH] what about typed methods? --- axum-extra/src/routing/mod.rs | 248 ++++++++---------------- axum-extra/src/routing/typed.rs | 315 ++++++++++++++++++++++++------- examples/hello-world/Cargo.toml | 3 - examples/hello-world/src/main.rs | 49 +---- examples/sse/src/main.rs | 5 +- 5 files changed, 337 insertions(+), 283 deletions(-) diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 884cc484..dc757f57 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -13,7 +13,10 @@ pub use self::resource::Resource; pub use axum_macros::TypedPath; #[cfg(feature = "typed-routing")] -pub use self::typed::{FirstElementIs, TypedPath}; +pub use self::typed::{ + Any, Delete, FirstTwoElementsAre, Get, Head, OneOf, Options, Patch, Post, Put, Trace, + TypedMethod, TypedPath, +}; /// Extension trait that adds additional methods to [`Router`]. pub trait RouterExt<B>: sealed::Sealed { @@ -42,108 +45,80 @@ pub trait RouterExt<B>: sealed::Sealed { where T: HasRoutes<B>; - /// Add a typed `GET` route to the router. + /// Add a typed route to the router. /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. + /// The method and path will be inferred from the first two arguments to the handler function + /// which must implement [`TypedMethod`] and [`TypedPath`] respectively. /// - /// See [`TypedPath`] for more details and examples. + /// # Example + /// + /// ```rust + /// use serde::Deserialize; + /// use axum::{Router, extract::Json}; + /// use axum_extra::routing::{ + /// TypedPath, + /// Get, + /// Post, + /// Delete, + /// RouterExt, // for `Router::typed_*` + /// }; + /// + /// // A type safe route with `/users/:id` as its associated path. + /// #[derive(TypedPath, Deserialize)] + /// #[typed_path("/users/:id")] + /// struct UsersMember { + /// id: u32, + /// } + /// + /// // A regular handler function that takes `Get` as its first argument and + /// // `UsersMember` as the second argument and thus creates a typed connection + /// // between this handler and `GET /users/:id`. + /// // + /// // The first argument must implement `TypedMethod` and the second must + /// // implement `TypedPath`. + /// async fn users_show( + /// _: Get, + /// UsersMember { id }: UsersMember, + /// ) { + /// // ... + /// } + /// + /// let app = Router::new() + /// // Add our typed route to the router. + /// // + /// // The method and path will be inferred to `GET /users/:id` since `users_show`'s + /// // first argument is `Get` and the second is `UsersMember`. + /// .typed_route(users_show) + /// .typed_route(users_create) + /// .typed_route(users_destroy); + /// + /// #[derive(TypedPath)] + /// #[typed_path("/users")] + /// struct UsersCollection; + /// + /// #[derive(Deserialize)] + /// struct UsersCreatePayload { /* ... */ } + /// + /// async fn users_create( + /// _: Post, + /// _: UsersCollection, + /// // Our handlers can accept other extractors. + /// Json(payload): Json<UsersCreatePayload>, + /// ) { + /// // ... + /// } + /// + /// async fn users_destroy(_: Delete, _: UsersCollection) { /* ... */ } + /// + /// # + /// # let app: Router<axum::body::Body> = app; + /// ``` #[cfg(feature = "typed-routing")] - fn typed_get<H, T, P>(self, handler: H) -> Self + fn typed_route<H, T, M, P>(self, handler: H) -> Self where H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `DELETE` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_delete<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `HEAD` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_head<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `OPTIONS` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_options<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `PATCH` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_patch<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `POST` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_post<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `PUT` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_put<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath; - - /// Add a typed `TRACE` route to the router. - /// - /// The path will be inferred from the first argument to the handler function which must - /// implement [`TypedPath`]. - /// - /// See [`TypedPath`] for more details and examples. - #[cfg(feature = "typed-routing")] - fn typed_trace<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, + T: FirstTwoElementsAre<M, P> + 'static, + M: TypedMethod, P: TypedPath; } @@ -159,83 +134,14 @@ where } #[cfg(feature = "typed-routing")] - fn typed_get<H, T, P>(self, handler: H) -> Self + fn typed_route<H, T, M, P>(self, handler: H) -> Self where H: Handler<T, B>, - T: FirstElementIs<P> + 'static, + T: FirstTwoElementsAre<M, P> + 'static, + M: TypedMethod, P: TypedPath, { - self.route(P::PATH, axum::routing::get(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_delete<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::delete(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_head<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::head(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_options<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::options(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_patch<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::patch(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_post<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::post(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_put<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::put(handler)) - } - - #[cfg(feature = "typed-routing")] - fn typed_trace<H, T, P>(self, handler: H) -> Self - where - H: Handler<T, B>, - T: FirstElementIs<P> + 'static, - P: TypedPath, - { - self.route(P::PATH, axum::routing::trace(handler)) + self.route(P::PATH, M::apply_method_router(handler)) } } diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index 0147c4ae..5485af7f 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -1,66 +1,16 @@ use super::sealed::Sealed; +use axum::{ + async_trait, + extract::{FromRequest, RequestParts}, + handler::Handler, + routing::MethodRouter, +}; +use std::{convert::Infallible, fmt, marker::PhantomData}; -/// A type safe path. +/// A typed request path. /// /// This is used to statically connect a path to its corresponding handler using -/// [`RouterExt::typed_get`], [`RouterExt::typed_post`], etc. -/// -/// # Example -/// -/// ```rust -/// use serde::Deserialize; -/// use axum::{Router, extract::Json}; -/// use axum_extra::routing::{ -/// TypedPath, -/// RouterExt, // for `Router::typed_*` -/// }; -/// -/// // A type safe route with `/users/:id` as its associated path. -/// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id")] -/// struct UsersMember { -/// id: u32, -/// } -/// -/// // A regular handler function that takes `UsersMember` as the first argument -/// // and thus creates a typed connection between this handler and the `/users/:id` path. -/// // -/// // The `TypedPath` must be the first argument to the function. -/// async fn users_show( -/// UsersMember { id }: UsersMember, -/// ) { -/// // ... -/// } -/// -/// let app = Router::new() -/// // Add our typed route to the router. -/// // -/// // The path will be inferred to `/users/:id` since `users_show`'s -/// // first argument is `UsersMember` which implements `TypedPath` -/// .typed_get(users_show) -/// .typed_post(users_create) -/// .typed_delete(users_destroy); -/// -/// #[derive(TypedPath)] -/// #[typed_path("/users")] -/// struct UsersCollection; -/// -/// #[derive(Deserialize)] -/// struct UsersCreatePayload { /* ... */ } -/// -/// async fn users_create( -/// _: UsersCollection, -/// // Our handlers can accept other extractors. -/// Json(payload): Json<UsersCreatePayload>, -/// ) { -/// // ... -/// } -/// -/// async fn users_destroy(_: UsersCollection) { /* ... */ } -/// -/// # -/// # let app: Router<axum::body::Body> = app; -/// ``` +/// [`RouterExt::typed_route`]. See that method for more details. /// /// # Using `#[derive(TypedPath)]` /// @@ -80,9 +30,9 @@ use super::sealed::Sealed; /// The macro expands to: /// /// - A `TypedPath` implementation. -/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`], -/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must -/// also implement [`serde::Deserialize`], unless it's a unit struct. +/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_route`]. This +/// implementation uses [`Path`] and thus your struct must also implement [`serde::Deserialize`], +/// unless it's a unit struct. /// - A [`Display`] implementation that interpolates the captures. This can be used to, among other /// things, create links to known paths and have them verified statically. Note that the /// [`Display`] implementation for each field must return something that's compatible with its @@ -118,8 +68,7 @@ use super::sealed::Sealed; /// ``` /// /// [`FromRequest`]: axum::extract::FromRequest -/// [`RouterExt::typed_get`]: super::RouterExt::typed_get -/// [`RouterExt::typed_post`]: super::RouterExt::typed_post +/// [`RouterExt::typed_route`]: super::RouterExt::typed_route /// [`Path`]: axum::extract::Path /// [`Display`]: std::fmt::Display /// [`Deserialize`]: serde::Deserialize @@ -128,6 +77,232 @@ pub trait TypedPath: std::fmt::Display { const PATH: &'static str; } +/// A typed HTTP method. +/// +/// This is used to statically connect an HTTP method to its corresponding handler using +/// [`RouterExt::typed_route`]. See that method for more details. +/// +/// This trait is sealed such that it cannot be implemented outside this crate. +/// +/// [`RouterExt::typed_route`]: super::RouterExt::typed_route +pub trait TypedMethod: Sealed { + /// Wrap a handler in a [`MethodRouter`] that accepts this type's corresponding HTTP method. + fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B> + where + H: Handler<T, B>, + B: Send + 'static, + T: 'static; + + /// Check if the request matches this type's corresponding HTTP method. + fn matches_method(method: &http::Method) -> bool; +} + +macro_rules! typed_method { + ($name:ident, $method_router_constructor:ident, $method:ident) => { + #[doc = concat!("A `TypedMethod` that accepts `", stringify!($method), "` requests.")] + #[derive(Clone, Copy, Debug)] + pub struct $name; + + impl Sealed for $name {} + + impl TypedMethod for $name { + fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B> + where + H: Handler<T, B>, + B: Send + 'static, + T: 'static, + { + axum::routing::$method_router_constructor(handler) + } + + fn matches_method(method: &http::Method) -> bool { + method == http::Method::$method + } + } + + #[async_trait] + impl<B> FromRequest<B> for $name + where + B: Send, + { + type Rejection = http::StatusCode; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + if Self::matches_method(req.method()) { + Ok(Self) + } else { + Err(http::StatusCode::NOT_FOUND) + } + } + } + }; +} + +typed_method!(Delete, delete, DELETE); +typed_method!(Get, get, GET); +typed_method!(Head, head, HEAD); +typed_method!(Options, options, OPTIONS); +typed_method!(Patch, patch, PATCH); +typed_method!(Post, post, POST); +typed_method!(Put, put, PUT); +typed_method!(Trace, trace, TRACE); + +/// A [`TypedMethod`] that accepts all HTTP methods. +/// +/// # Example +/// +/// ```rust +/// use axum_extra::routing::{TypedPath, Any, RouterExt}; +/// use axum::Router; +/// +/// #[derive(TypedPath)] +/// #[typed_path("/foo")] +/// struct Foo; +/// +/// // This accepts `/foo` with any HTTP method. +/// async fn foo(_: Any, _: Foo) {} +/// +/// let app = Router::new().typed_route(foo); +/// # +/// # let app: Router<axum::body::Body> = app; +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct Any; + +impl Sealed for Any {} + +impl TypedMethod for Any { + fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B> + where + H: Handler<T, B>, + B: Send + 'static, + T: 'static, + { + axum::routing::any(handler) + } + + fn matches_method(_method: &http::Method) -> bool { + true + } +} + +#[async_trait] +impl<B> FromRequest<B> for Any +where + B: Send, +{ + type Rejection = Infallible; + + async fn from_request(_: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + Ok(Self) + } +} + +/// A [`TypedMethod`] that accepts one of a number of HTTP methods. +/// +/// # Example +/// +/// ```rust +/// use axum_extra::routing::{TypedPath, OneOf, Patch, Put, RouterExt}; +/// use axum::Router; +/// +/// #[derive(TypedPath)] +/// #[typed_path("/foo")] +/// struct Foo; +/// +/// // This accepts `PATCH /foo` and `PUT /foo` +/// async fn foo(_: OneOf<(Patch, Put)>, _: Foo) {} +/// +/// let app = Router::new().typed_route(foo); +/// # +/// # let app: Router<axum::body::Body> = app; +/// ``` +pub struct OneOf<T>(PhantomData<T>); + +macro_rules! one_of { + ($($ty:ident),* $(,)?) => { + impl<$($ty,)*> TypedMethod for OneOf<($($ty,)*)> + where + $( $ty: TypedMethod, )* + { + #[allow(clippy::redundant_clone, unused_mut, unused_variables)] + fn apply_method_router<H, B, T>(handler: H) -> MethodRouter<B> + where + H: Handler<T, B>, + B: Send + 'static, + T: 'static, + { + let mut method_router = MethodRouter::new(); + $( + method_router = method_router.merge($ty::apply_method_router(handler.clone())); + )* + method_router + } + + #[allow(unused_variables)] + fn matches_method(method: &http::Method) -> bool { + $( + if $ty::matches_method(method) { + return true; + } + )* + false + } + } + + impl<$($ty,)*> Sealed for OneOf<($($ty,)*)> {} + + #[async_trait] + impl<B, $($ty,)*> FromRequest<B> for OneOf<($($ty,)*)> + where + B: Send, + $( $ty: TypedMethod + FromRequest<B>, )* + { + type Rejection = http::StatusCode; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + if Self::matches_method(req.method()) { + Ok(Self(PhantomData)) + } else { + Err(http::StatusCode::NOT_FOUND) + } + } + } + }; +} + +one_of!(); +one_of!(T1,); +one_of!(T1, T2); +one_of!(T1, T2, T3); +one_of!(T1, T2, T3, T4); +one_of!(T1, T2, T3, T4, T5); +one_of!(T1, T2, T3, T4, T5, T6); +one_of!(T1, T2, T3, T4, T5, T6, T7); +one_of!(T1, T2, T3, T4, T5, T6, T7, T8); + +impl<T> fmt::Debug for OneOf<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("OneOf") + .field(&format_args!("{}", std::any::type_name::<T>())) + .finish() + } +} + +impl<T> Default for OneOf<T> { + fn default() -> Self { + Self(Default::default()) + } +} + +impl<T> Clone for OneOf<T> { + fn clone(&self) -> Self { + Self(self.0) + } +} + +impl<T> Copy for OneOf<T> {} + /// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a /// given type. /// @@ -139,18 +314,20 @@ pub trait TypedPath: std::fmt::Display { /// It is sealed such that it cannot be implemented outside this crate. /// /// [`RouterExt`]: super::RouterExt -pub trait FirstElementIs<P>: Sealed {} +pub trait FirstTwoElementsAre<M, P>: Sealed {} macro_rules! impl_first_element_is { ( $($ty:ident),* $(,)? ) => { - impl<P, $($ty,)*> FirstElementIs<P> for (P, $($ty,)*) + impl<M, P, $($ty,)*> FirstTwoElementsAre<M, P> for (M, P, $($ty,)*) where - P: TypedPath + M: TypedMethod, + P: TypedPath, {} - impl<P, $($ty,)*> Sealed for (P, $($ty,)*) + impl<M, P, $($ty,)*> Sealed for (M, P, $($ty,)*) where - P: TypedPath + M: TypedMethod, + P: TypedPath, {} }; } diff --git a/examples/hello-world/Cargo.toml b/examples/hello-world/Cargo.toml index 54a2fd98..36b5dfc6 100644 --- a/examples/hello-world/Cargo.toml +++ b/examples/hello-world/Cargo.toml @@ -7,6 +7,3 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } - -axum-extra = { path = "../../axum-extra", features = ["typed-routing"] } -serde = { version = "1.0", features = ["derive"] } diff --git a/examples/hello-world/src/main.rs b/examples/hello-world/src/main.rs index 07105679..466caceb 100644 --- a/examples/hello-world/src/main.rs +++ b/examples/hello-world/src/main.rs @@ -4,52 +4,23 @@ //! cargo run -p example-hello-world //! ``` -// Just using this file for manual testing. Will be cleaned up before an eventual merge - -use axum::{response::IntoResponse, Router}; -use axum_extra::routing::{RouterExt, TypedPath}; -use serde::Deserialize; +use axum::{response::Html, routing::get, Router}; +use std::net::SocketAddr; #[tokio::main] async fn main() { - let app = Router::new() - .typed_get(users_index) - .typed_post(users_create) - .typed_get(users_show) - .typed_get(users_edit); + // build our application with a route + let app = Router::new().route("/", get(handler)); - axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) + // run it + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + println!("listening on {}", addr); + axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } -#[derive(TypedPath)] -#[typed_path("/users")] -struct UsersCollection; - -#[derive(Deserialize, TypedPath)] -#[typed_path("/users/:id")] -struct UsersMember { - id: u32, -} - -#[derive(Deserialize, TypedPath)] -#[typed_path("/users/:id/edit")] -struct UsersEdit(u32); - -async fn users_index(_: UsersCollection) -> impl IntoResponse { - "users#index" -} - -async fn users_create(_: UsersCollection, _payload: String) -> impl IntoResponse { - "users#create" -} - -async fn users_show(UsersMember { id }: UsersMember) -> impl IntoResponse { - format!("users#show: {}", id) -} - -async fn users_edit(UsersEdit(id): UsersEdit) -> impl IntoResponse { - format!("users#edit: {}", id) +async fn handler() -> Html<&'static str> { + Html("<h1>Hello, World!</h1>") } diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 772e9d01..270fa4c9 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -37,7 +37,10 @@ async fn main() { let app = Router::new() .fallback(static_files_service) .route("/sse", get(sse_handler)) - .layer(TraceLayer::new_for_http()); + .layer( + TraceLayer::new_for_http() + .make_span_with(tower_http::trace::DefaultMakeSpan::new().include_headers(true)), + ); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000));