From 852e548e19c7d44de469a37688d5e233d5b475bd Mon Sep 17 00:00:00 2001
From: David Pedersen <david.pdrsn@gmail.com>
Date: Sun, 8 May 2022 20:04:56 +0200
Subject: [PATCH] Support `#[derive(FromRequest)]` on enums (#1009)

* Support `#[from_request(via(...))]` on enums

* Check `#[from_request]` on variants

* check for non enum/struct and clean up

* changelog

* changelog

* remove needless feature

* changelog ref
---
 axum-macros/CHANGELOG.md                      |   2 +
 axum-macros/src/from_request.rs               | 146 +++++++++++++++---
 axum-macros/src/from_request/attr.rs          |   6 +-
 .../enum_from_request_ident_in_variant.rs     |  12 ++
 .../enum_from_request_ident_in_variant.stderr |   5 +
 .../fail/enum_from_request_on_variant.rs      |  10 ++
 .../fail/enum_from_request_on_variant.stderr  |   5 +
 .../tests/from_request/fail/enum_no_via.rs    |   6 +
 .../from_request/fail/enum_no_via.stderr      |   7 +
 .../fail/enum_rejection_derive.rs             |   7 +
 .../fail/enum_rejection_derive.stderr         |   5 +
 .../from_request/fail/not_enum_or_struct.rs   |   6 +
 .../fail/not_enum_or_struct.stderr            |  11 ++
 .../tests/from_request/pass/enum_via.rs       |  12 ++
 14 files changed, 215 insertions(+), 25 deletions(-)
 create mode 100644 axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs
 create mode 100644 axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.stderr
 create mode 100644 axum-macros/tests/from_request/fail/enum_from_request_on_variant.rs
 create mode 100644 axum-macros/tests/from_request/fail/enum_from_request_on_variant.stderr
 create mode 100644 axum-macros/tests/from_request/fail/enum_no_via.rs
 create mode 100644 axum-macros/tests/from_request/fail/enum_no_via.stderr
 create mode 100644 axum-macros/tests/from_request/fail/enum_rejection_derive.rs
 create mode 100644 axum-macros/tests/from_request/fail/enum_rejection_derive.stderr
 create mode 100644 axum-macros/tests/from_request/fail/not_enum_or_struct.rs
 create mode 100644 axum-macros/tests/from_request/fail/not_enum_or_struct.stderr
 create mode 100644 axum-macros/tests/from_request/pass/enum_via.rs

diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md
index 6a1377d0..a983d92b 100644
--- a/axum-macros/CHANGELOG.md
+++ b/axum-macros/CHANGELOG.md
@@ -9,9 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - **fixed:** `Option` and `Result` are now supported in typed path route handler parameters ([#1001])
 - **fixed:** Support wildcards in typed paths ([#1003])
+- **added:** Support `#[derive(FromRequest)]` on enums using `#[from_request(via(OtherExtractor))]` ([#1009])
 
 [#1001]: https://github.com/tokio-rs/axum/pull/1001
 [#1003]: https://github.com/tokio-rs/axum/pull/1003
+[#1009]: https://github.com/tokio-rs/axum/pull/1009
 
 # 0.2.0 (31. March, 2022)
 
diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs
index 901b2688..76634ee9 100644
--- a/axum-macros/src/from_request.rs
+++ b/axum-macros/src/from_request.rs
@@ -3,24 +3,77 @@ use self::attr::{
     RejectionDeriveOptOuts,
 };
 use heck::ToUpperCamelCase;
-use proc_macro2::TokenStream;
+use proc_macro2::{Span, TokenStream};
 use quote::{format_ident, quote, quote_spanned};
 use syn::{punctuated::Punctuated, spanned::Spanned, Token};
 
 mod attr;
 
-const GENERICS_ERROR: &str = "`#[derive(FromRequest)] doesn't support generics";
+pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
+    match item {
+        syn::Item::Struct(item) => {
+            let syn::ItemStruct {
+                attrs,
+                ident,
+                generics,
+                fields,
+                semi_token: _,
+                vis,
+                struct_token: _,
+            } = item;
 
-pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result<TokenStream> {
-    let syn::ItemStruct {
-        attrs,
-        ident,
-        generics,
-        fields,
-        semi_token: _,
-        vis,
-        struct_token: _,
-    } = item;
+            error_on_generics(generics)?;
+
+            match parse_container_attrs(&attrs)? {
+                FromRequestContainerAttr::Via(path) => {
+                    impl_struct_by_extracting_all_at_once(ident, fields, path)
+                }
+                FromRequestContainerAttr::RejectionDerive(_, opt_outs) => {
+                    impl_struct_by_extracting_each_field(ident, fields, vis, opt_outs)
+                }
+                FromRequestContainerAttr::None => impl_struct_by_extracting_each_field(
+                    ident,
+                    fields,
+                    vis,
+                    RejectionDeriveOptOuts::default(),
+                ),
+            }
+        }
+        syn::Item::Enum(item) => {
+            let syn::ItemEnum {
+                attrs,
+                vis: _,
+                enum_token: _,
+                ident,
+                generics,
+                brace_token: _,
+                variants,
+            } = item;
+
+            error_on_generics(generics)?;
+
+            match parse_container_attrs(&attrs)? {
+                FromRequestContainerAttr::Via(path) => {
+                    impl_enum_by_extracting_all_at_once(ident, variants, path)
+                }
+                FromRequestContainerAttr::RejectionDerive(rejection_derive, _) => {
+                    Err(syn::Error::new_spanned(
+                        rejection_derive,
+                        "cannot use `rejection_derive` on enums",
+                    ))
+                }
+                FromRequestContainerAttr::None => Err(syn::Error::new(
+                    Span::call_site(),
+                    "missing `#[from_request(via(...))]`",
+                )),
+            }
+        }
+        _ => Err(syn::Error::new_spanned(item, "expected `struct` or `enum`")),
+    }
+}
+
+fn error_on_generics(generics: syn::Generics) -> syn::Result<()> {
+    const GENERICS_ERROR: &str = "`#[derive(FromRequest)] doesn't support generics";
 
     if !generics.params.is_empty() {
         return Err(syn::Error::new_spanned(generics, GENERICS_ERROR));
@@ -30,18 +83,10 @@ pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result<TokenStream> {
         return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR));
     }
 
-    match parse_container_attrs(&attrs)? {
-        FromRequestContainerAttr::Via(path) => impl_by_extracting_all_at_once(ident, fields, path),
-        FromRequestContainerAttr::RejectionDerive(opt_outs) => {
-            impl_by_extracting_each_field(ident, fields, vis, opt_outs)
-        }
-        FromRequestContainerAttr::None => {
-            impl_by_extracting_each_field(ident, fields, vis, RejectionDeriveOptOuts::default())
-        }
-    }
+    Ok(())
 }
 
-fn impl_by_extracting_each_field(
+fn impl_struct_by_extracting_each_field(
     ident: syn::Ident,
     fields: syn::Fields,
     vis: syn::Visibility,
@@ -413,7 +458,7 @@ fn rejection_variant_name(field: &syn::Field) -> syn::Result<syn::Ident> {
     }
 }
 
-fn impl_by_extracting_all_at_once(
+fn impl_struct_by_extracting_all_at_once(
     ident: syn::Ident,
     fields: syn::Fields,
     path: syn::Path,
@@ -459,6 +504,61 @@ fn impl_by_extracting_all_at_once(
     })
 }
 
+fn impl_enum_by_extracting_all_at_once(
+    ident: syn::Ident,
+    variants: Punctuated<syn::Variant, Token![,]>,
+    path: syn::Path,
+) -> syn::Result<TokenStream> {
+    for variant in variants {
+        let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?;
+        if let Some((via, _)) = via {
+            return Err(syn::Error::new_spanned(
+                via,
+                "`#[from_request(via(...))]` cannot be used on variants",
+            ));
+        }
+
+        let fields = match variant.fields {
+            syn::Fields::Named(fields) => fields.named.into_iter(),
+            syn::Fields::Unnamed(fields) => fields.unnamed.into_iter(),
+            syn::Fields::Unit => Punctuated::<_, Token![,]>::new().into_iter(),
+        };
+
+        for field in fields {
+            let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
+            if let Some((via, _)) = via {
+                return Err(syn::Error::new_spanned(
+                    via,
+                    "`#[from_request(via(...))]` cannot be used inside variants",
+                ));
+            }
+        }
+    }
+
+    let path_span = path.span();
+
+    Ok(quote_spanned! {path_span=>
+        #[::axum::async_trait]
+        #[automatically_derived]
+        impl<B> ::axum::extract::FromRequest<B> for #ident
+        where
+            B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
+            B::Data: ::std::marker::Send,
+            B::Error: ::std::convert::Into<::axum::BoxError>,
+        {
+            type Rejection = <#path<Self> as ::axum::extract::FromRequest<B>>::Rejection;
+
+            async fn from_request(
+                req: &mut ::axum::extract::RequestParts<B>,
+            ) -> ::std::result::Result<Self, Self::Rejection> {
+                ::axum::extract::FromRequest::<B>::from_request(req)
+                    .await
+                    .map(|#path(inner)| inner)
+            }
+        }
+    })
+}
+
 #[test]
 fn ui() {
     #[rustversion::stable]
diff --git a/axum-macros/src/from_request/attr.rs b/axum-macros/src/from_request/attr.rs
index 25c989a0..657c7bd6 100644
--- a/axum-macros/src/from_request/attr.rs
+++ b/axum-macros/src/from_request/attr.rs
@@ -12,7 +12,7 @@ pub(crate) struct FromRequestFieldAttr {
 
 pub(crate) enum FromRequestContainerAttr {
     Via(syn::Path),
-    RejectionDerive(RejectionDeriveOptOuts),
+    RejectionDerive(kw::rejection_derive, RejectionDeriveOptOuts),
     None,
 }
 
@@ -91,7 +91,9 @@ pub(crate) fn parse_container_attrs(
             }
         }
         (Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via(path)),
-        (None, Some((_, _, opt_outs))) => Ok(FromRequestContainerAttr::RejectionDerive(opt_outs)),
+        (None, Some((_, rejection_derive, opt_outs))) => Ok(
+            FromRequestContainerAttr::RejectionDerive(rejection_derive, opt_outs),
+        ),
         (None, None) => Ok(FromRequestContainerAttr::None),
     }
 }
diff --git a/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs
new file mode 100644
index 00000000..336850e5
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs
@@ -0,0 +1,12 @@
+use axum_macros::FromRequest;
+
+#[derive(FromRequest, Clone)]
+#[from_request(via(axum::Extension))]
+enum Extractor {
+    Foo {
+        #[from_request(via(axum::Extension))]
+        foo: (),
+    }
+}
+
+fn main() {}
diff --git a/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.stderr b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.stderr
new file mode 100644
index 00000000..f998004a
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.stderr
@@ -0,0 +1,5 @@
+error: `#[from_request(via(...))]` cannot be used inside variants
+ --> tests/from_request/fail/enum_from_request_ident_in_variant.rs:7:24
+  |
+7 |         #[from_request(via(axum::Extension))]
+  |                        ^^^
diff --git a/axum-macros/tests/from_request/fail/enum_from_request_on_variant.rs b/axum-macros/tests/from_request/fail/enum_from_request_on_variant.rs
new file mode 100644
index 00000000..6b825077
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_from_request_on_variant.rs
@@ -0,0 +1,10 @@
+use axum_macros::FromRequest;
+
+#[derive(FromRequest, Clone)]
+#[from_request(via(axum::Extension))]
+enum Extractor {
+    #[from_request(via(axum::Extension))]
+    Foo,
+}
+
+fn main() {}
diff --git a/axum-macros/tests/from_request/fail/enum_from_request_on_variant.stderr b/axum-macros/tests/from_request/fail/enum_from_request_on_variant.stderr
new file mode 100644
index 00000000..9818da51
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_from_request_on_variant.stderr
@@ -0,0 +1,5 @@
+error: `#[from_request(via(...))]` cannot be used on variants
+ --> tests/from_request/fail/enum_from_request_on_variant.rs:6:20
+  |
+6 |     #[from_request(via(axum::Extension))]
+  |                    ^^^
diff --git a/axum-macros/tests/from_request/fail/enum_no_via.rs b/axum-macros/tests/from_request/fail/enum_no_via.rs
new file mode 100644
index 00000000..f5add219
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_no_via.rs
@@ -0,0 +1,6 @@
+use axum_macros::FromRequest;
+
+#[derive(FromRequest, Clone)]
+enum Extractor {}
+
+fn main() {}
diff --git a/axum-macros/tests/from_request/fail/enum_no_via.stderr b/axum-macros/tests/from_request/fail/enum_no_via.stderr
new file mode 100644
index 00000000..e1f86d53
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_no_via.stderr
@@ -0,0 +1,7 @@
+error: missing `#[from_request(via(...))]`
+ --> tests/from_request/fail/enum_no_via.rs:3:10
+  |
+3 | #[derive(FromRequest, Clone)]
+  |          ^^^^^^^^^^^
+  |
+  = note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info)
diff --git a/axum-macros/tests/from_request/fail/enum_rejection_derive.rs b/axum-macros/tests/from_request/fail/enum_rejection_derive.rs
new file mode 100644
index 00000000..f343d544
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_rejection_derive.rs
@@ -0,0 +1,7 @@
+use axum_macros::FromRequest;
+
+#[derive(FromRequest, Clone)]
+#[from_request(rejection_derive(!Error))]
+enum Extractor {}
+
+fn main() {}
diff --git a/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr b/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr
new file mode 100644
index 00000000..1e721d76
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/enum_rejection_derive.stderr
@@ -0,0 +1,5 @@
+error: cannot use `rejection_derive` on enums
+ --> tests/from_request/fail/enum_rejection_derive.rs:4:16
+  |
+4 | #[from_request(rejection_derive(!Error))]
+  |                ^^^^^^^^^^^^^^^^
diff --git a/axum-macros/tests/from_request/fail/not_enum_or_struct.rs b/axum-macros/tests/from_request/fail/not_enum_or_struct.rs
new file mode 100644
index 00000000..989ca26d
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/not_enum_or_struct.rs
@@ -0,0 +1,6 @@
+use axum_macros::FromRequest;
+
+#[derive(FromRequest)]
+union Extractor {}
+
+fn main() {}
diff --git a/axum-macros/tests/from_request/fail/not_enum_or_struct.stderr b/axum-macros/tests/from_request/fail/not_enum_or_struct.stderr
new file mode 100644
index 00000000..6a755643
--- /dev/null
+++ b/axum-macros/tests/from_request/fail/not_enum_or_struct.stderr
@@ -0,0 +1,11 @@
+error: expected `struct` or `enum`
+ --> tests/from_request/fail/not_enum_or_struct.rs:4:1
+  |
+4 | union Extractor {}
+  | ^^^^^^^^^^^^^^^^^^
+
+error: unions cannot have zero fields
+ --> tests/from_request/fail/not_enum_or_struct.rs:4:1
+  |
+4 | union Extractor {}
+  | ^^^^^^^^^^^^^^^^^^
diff --git a/axum-macros/tests/from_request/pass/enum_via.rs b/axum-macros/tests/from_request/pass/enum_via.rs
new file mode 100644
index 00000000..d6ba90e2
--- /dev/null
+++ b/axum-macros/tests/from_request/pass/enum_via.rs
@@ -0,0 +1,12 @@
+use axum::{body::Body, routing::get, Extension, Router};
+use axum_macros::FromRequest;
+
+#[derive(FromRequest, Clone)]
+#[from_request(via(Extension))]
+enum Extractor {}
+
+async fn foo(_: Extractor) {}
+
+fn main() {
+    Router::<Body>::new().route("/", get(foo));
+}