diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 432bb293..f74ca796 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -15,5 +15,13 @@ erased-json = ["serde", "serde_json"] [dependencies] axum = { path = "../axum", version = "0.3" } +tower-service = "0.3" + +# optional dependencies serde = { version = "1.0.130", optional = true } serde_json = { version = "1.0.71", optional = true } + +[dev-dependencies] +hyper = "0.14" +tokio = { version = "1.14", features = ["full"] } +tower = { version = "0.4", features = ["util"] } diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 945d9318..d1c04dd3 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -43,3 +43,4 @@ #![cfg_attr(test, allow(clippy::float_cmp))] pub mod response; +pub mod routing; diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs new file mode 100644 index 00000000..df83ff52 --- /dev/null +++ b/axum-extra/src/routing/mod.rs @@ -0,0 +1,35 @@ +//! Additional types for defining routes. + +use axum::Router; + +mod resource; + +pub use self::resource::Resource; + +/// Extension trait that adds additional methods to [`Router`]. +pub trait RouterExt: sealed::Sealed { + /// Add a [`Resource`] to the router. + /// + /// See [`Resource`] for more details. + fn resource(self, name: &str, f: F) -> Self + where + F: FnOnce(resource::Resource) -> resource::Resource; +} + +impl RouterExt for Router { + fn resource(self, name: &str, f: F) -> Self + where + F: FnOnce(resource::Resource) -> resource::Resource, + { + f(resource::Resource { + name: name.to_owned(), + router: self, + }) + .router + } +} + +mod sealed { + pub trait Sealed {} + impl Sealed for axum::Router {} +} diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs new file mode 100644 index 00000000..636cc0db --- /dev/null +++ b/axum-extra/src/routing/resource.rs @@ -0,0 +1,273 @@ +use axum::{ + body::{Body, BoxBody}, + handler::Handler, + http::{Request, Response}, + routing::{delete, get, on, post, MethodFilter}, + Router, +}; +use std::convert::Infallible; +use tower_service::Service; + +/// A resource which defines a set of conventional CRUD routes. +/// +/// # Example +/// +/// ```rust +/// use axum::{Router, routing::get, extract::Path}; +/// use axum_extra::routing::RouterExt; +/// +/// let app = Router::new().resource("users", |r| { +/// // Define a route for `GET /users` +/// r.index(|| async {}) +/// // `POST /users` +/// .create(|| async {}) +/// // `GET /users/new` +/// .new(|| async {}) +/// // `GET /users/:users_id` +/// .show(|Path(user_id): Path| async {}) +/// // `GET /users/:users_id/edit` +/// .edit(|Path(user_id): Path| async {}) +/// // `PUT or PATCH /users/:users_id` +/// .update(|Path(user_id): Path| async {}) +/// // `DELETE /users/:users_id` +/// .destroy(|Path(user_id): Path| async {}) +/// // Nest another router at the "member level" +/// // This defines a route for `GET /users/:users_id/tweets` +/// .nest(Router::new().route( +/// "/tweets", +/// get(|Path(user_id): Path| async {}), +/// )) +/// // Nest another router at the "collection level" +/// // This defines a route for `GET /users/featured` +/// .nest_collection( +/// Router::new().route("/featured", get(|| async {})), +/// ) +/// }); +/// # let _: Router = app; +/// ``` +#[derive(Debug)] +pub struct Resource { + pub(crate) name: String, + pub(crate) router: Router, +} + +impl Resource { + fn index_create_path(&self) -> String { + format!("/{}", self.name) + } + + fn show_update_destroy_path(&self) -> String { + format!("/{0}/:{0}_id", self.name) + } + + fn route(mut self, path: &str, svc: T) -> Self + where + T: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + T::Future: Send + 'static, + { + self.router = self.router.route(path, svc); + self + } + + /// Add a handler at `GET /resource_name`. + pub fn index(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = self.index_create_path(); + self.route(&path, get(handler)) + } + + /// Add a handler at `POST /{resource_name}`. + pub fn create(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = self.index_create_path(); + self.route(&path, post(handler)) + } + + /// Add a handler at `GET /{resource_name}/new`. + pub fn new(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = format!("/{}/new", self.name); + self.route(&path, get(handler)) + } + + /// Add a handler at `GET /{resource_name}/:{resource_name}_id`. + pub fn show(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = self.show_update_destroy_path(); + self.route(&path, get(handler)) + } + + /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. + pub fn edit(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = format!("/{0}/:{0}_id/edit", self.name); + self.route(&path, get(handler)) + } + + /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. + pub fn update(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = self.show_update_destroy_path(); + self.route(&path, on(MethodFilter::PUT | MethodFilter::PATCH, handler)) + } + + /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. + pub fn destroy(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let path = self.show_update_destroy_path(); + self.route(&path, delete(handler)) + } + + /// Nest another route at the "member level". + /// + /// The routes will be nested at `/{resource_name}/:{resource_name}_id`. + pub fn nest(mut self, svc: T) -> Self + where + T: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + T::Future: Send + 'static, + { + let path = self.show_update_destroy_path(); + self.router = self.router.nest(&path, svc); + self + } + + /// Nest another route at the "collection level". + /// + /// The routes will be nested at `/{resource_name}`. + pub fn nest_collection(mut self, svc: T) -> Self + where + T: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + T::Future: Send + 'static, + { + let path = self.index_create_path(); + self.router = self.router.nest(&path, svc); + self + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use crate::routing::RouterExt; + use axum::{extract::Path, http::Method, Router}; + use tower::ServiceExt; + + #[tokio::test] + async fn works() { + let mut app = Router::new().resource("users", |r| { + r.index(|| async { "users#index" }) + .create(|| async { "users#create" }) + .new(|| async { "users#new" }) + .show(|Path(id): Path| async move { format!("users#show id={}", id) }) + .edit(|Path(id): Path| async move { format!("users#edit id={}", id) }) + .update(|Path(id): Path| async move { format!("users#update id={}", id) }) + .destroy(|Path(id): Path| async move { format!("users#destroy id={}", id) }) + .nest(Router::new().route( + "/tweets", + get(|Path(id): Path| async move { format!("users#tweets id={}", id) }), + )) + .nest_collection( + Router::new().route("/featured", get(|| async move { "users#featured" })), + ) + }); + + assert_eq!( + call_route(&mut app, Method::GET, "/users").await, + "users#index" + ); + + assert_eq!( + call_route(&mut app, Method::POST, "/users").await, + "users#create" + ); + + assert_eq!( + call_route(&mut app, Method::GET, "/users/new").await, + "users#new" + ); + + assert_eq!( + call_route(&mut app, Method::GET, "/users/1").await, + "users#show id=1" + ); + + assert_eq!( + call_route(&mut app, Method::GET, "/users/1/edit").await, + "users#edit id=1" + ); + + assert_eq!( + call_route(&mut app, Method::PATCH, "/users/1").await, + "users#update id=1" + ); + + assert_eq!( + call_route(&mut app, Method::PUT, "/users/1").await, + "users#update id=1" + ); + + assert_eq!( + call_route(&mut app, Method::DELETE, "/users/1").await, + "users#destroy id=1" + ); + + assert_eq!( + call_route(&mut app, Method::GET, "/users/1/tweets").await, + "users#tweets id=1" + ); + + assert_eq!( + call_route(&mut app, Method::GET, "/users/featured").await, + "users#featured" + ); + } + + async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { + let res = app + .ready() + .await + .unwrap() + .call( + Request::builder() + .method(method) + .uri(uri) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + let bytes = hyper::body::to_bytes(res).await.unwrap(); + String::from_utf8(bytes.to_vec()).unwrap() + } +}