Refactor proc-macro attribute parsing (#1369)

* Refactor proc-macro attribute parsing

* Remove `#[allow(warnings)]` which was accidentally committed

* Change span for "cannot use `rejection` without `via`" error for enums

* fix test
This commit is contained in:
David Pedersen 2022-09-12 20:10:58 +02:00 committed by GitHub
parent 54d8439e35
commit 8da69a98fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 156 additions and 177 deletions

View file

@ -0,0 +1,61 @@
use quote::ToTokens;
use syn::parse::{Parse, ParseStream};
pub(crate) fn parse_parenthesized_attribute<K, T>(
input: ParseStream,
out: &mut Option<(K, T)>,
) -> syn::Result<()>
where
K: Parse + ToTokens,
T: Parse,
{
let kw = input.parse()?;
let content;
syn::parenthesized!(content in input);
let inner = content.parse()?;
if out.is_some() {
let kw_name = std::any::type_name::<K>().split("::").last().unwrap();
let msg = format!("`{}` specified more than once", kw_name);
return Err(syn::Error::new_spanned(kw, msg));
}
*out = Some((kw, inner));
Ok(())
}
pub(crate) trait Combine: Sized {
fn combine(self, other: Self) -> syn::Result<Self>;
}
pub(crate) fn parse_attrs<T>(ident: &str, attrs: &[syn::Attribute]) -> syn::Result<T>
where
T: Combine + Default + Parse,
{
attrs
.iter()
.filter(|attr| attr.path.is_ident(ident))
.map(|attr| attr.parse_args::<T>())
.try_fold(T::default(), |out, next| out.combine(next?))
}
pub(crate) fn combine_attribute<K, T>(a: &mut Option<(K, T)>, b: Option<(K, T)>) -> syn::Result<()>
where
K: ToTokens,
{
if let Some((kw, inner)) = b {
if a.is_some() {
let kw_name = std::any::type_name::<K>().split("::").last().unwrap();
let msg = format!("`{}` specified more than once", kw_name);
return Err(syn::Error::new_spanned(kw, msg));
}
*a = Some((kw, inner));
}
Ok(())
}
pub(crate) fn second<T, K>(tuple: (T, K)) -> K {
tuple.1
}

View file

@ -1,5 +1,7 @@
use self::attr::{
parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr,
use self::attr::FromRequestContainerAttrs;
use crate::{
attr_parsing::{parse_attrs, second},
from_request::attr::FromRequestFieldAttrs,
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
@ -38,26 +40,20 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via { path, rejection } => {
impl_struct_by_extracting_all_at_once(
ident,
fields,
path,
rejection,
generic_ident,
tr,
)
}
FromRequestContainerAttr::Rejection(rejection) => {
error_on_generic_ident(generic_ident, tr)?;
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
impl_struct_by_extracting_each_field(ident, fields, Some(rejection), tr)
}
FromRequestContainerAttr::None => {
match (via.map(second), rejection.map(second)) {
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
ident,
fields,
via,
rejection,
generic_ident,
tr,
),
(None, rejection) => {
error_on_generic_ident(generic_ident, tr)?;
impl_struct_by_extracting_each_field(ident, fields, None, tr)
impl_struct_by_extracting_each_field(ident, fields, rejection, tr)
}
}
}
@ -82,15 +78,21 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
return Err(syn::Error::new_spanned(where_clause, generics_error));
}
match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via { path, rejection } => {
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection, tr)
}
FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned(
rejection,
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
match (via.map(second), rejection) {
(Some(via), rejection) => impl_enum_by_extracting_all_at_once(
ident,
variants,
via,
rejection.map(second),
tr,
),
(None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
rejection_kw,
"cannot use `rejection` without `via`",
)),
FromRequestContainerAttr::None => Err(syn::Error::new(
(None, _) => Err(syn::Error::new(
Span::call_site(),
"missing `#[from_request(via(...))]`",
)),
@ -316,7 +318,7 @@ fn extract_fields(
let mut res: Vec<_> = fields_iter
.enumerate()
.map(|(index, field)| {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
let member = member(field, index);
let ty_span = field.ty.span();
@ -434,7 +436,7 @@ fn extract_fields(
// Handle the last element, if deriving FromRequest
if let Some(field) = last {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
let member = member(field, fields.len() - 1);
let ty_span = field.ty.span();
@ -557,7 +559,8 @@ fn impl_struct_by_extracting_all_at_once(
};
for field in fields {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,
@ -695,7 +698,8 @@ fn impl_enum_by_extracting_all_at_once(
tr: Trait,
) -> syn::Result<TokenStream> {
for variant in variants {
let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?;
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &variant.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,
@ -710,7 +714,7 @@ fn impl_enum_by_extracting_all_at_once(
};
for field in fields {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,

View file

@ -1,166 +1,79 @@
use quote::ToTokens;
use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, Combine};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Token,
};
#[derive(Default)]
pub(crate) struct FromRequestFieldAttr {
pub(crate) via: Option<(kw::via, syn::Path)>,
}
pub(crate) enum FromRequestContainerAttr {
Via {
path: syn::Path,
rejection: Option<syn::Path>,
},
Rejection(syn::Path),
None,
}
pub(crate) mod kw {
syn::custom_keyword!(via);
syn::custom_keyword!(rejection);
syn::custom_keyword!(Display);
syn::custom_keyword!(Debug);
syn::custom_keyword!(Error);
}
pub(crate) fn parse_field_attrs(attrs: &[syn::Attribute]) -> syn::Result<FromRequestFieldAttr> {
let attrs = parse_attrs(attrs)?;
let mut out = FromRequestFieldAttr::default();
for from_request_attr in attrs {
match from_request_attr {
FieldAttr::Via { via, path } => {
if out.via.is_some() {
return Err(double_attr_error("via", via));
} else {
out.via = Some((via, path));
}
}
}
}
Ok(out)
#[derive(Default)]
pub(super) struct FromRequestContainerAttrs {
pub(super) via: Option<(kw::via, syn::Path)>,
pub(super) rejection: Option<(kw::rejection, syn::Path)>,
}
pub(crate) fn parse_container_attrs(
attrs: &[syn::Attribute],
) -> syn::Result<FromRequestContainerAttr> {
let attrs = parse_attrs::<ContainerAttr>(attrs)?;
let mut out_via = None;
let mut out_rejection = None;
// we track the index of the attribute to know which comes last
// used to give more accurate error messages
for (idx, from_request_attr) in attrs.into_iter().enumerate() {
match from_request_attr {
ContainerAttr::Via { via, path } => {
if out_via.is_some() {
return Err(double_attr_error("via", via));
} else {
out_via = Some((idx, via, path));
}
}
ContainerAttr::Rejection { rejection, path } => {
if out_rejection.is_some() {
return Err(double_attr_error("rejection", rejection));
} else {
out_rejection = Some((idx, rejection, path));
}
}
}
}
match (out_via, out_rejection) {
(Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via {
path,
rejection: None,
}),
(Some((_, _, path)), Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Via {
path,
rejection: Some(rejection),
}),
(None, Some((_, _, rejection))) => Ok(FromRequestContainerAttr::Rejection(rejection)),
(None, None) => Ok(FromRequestContainerAttr::None),
}
}
pub(crate) fn parse_attrs<T>(attrs: &[syn::Attribute]) -> syn::Result<Punctuated<T, Token![,]>>
where
T: Parse,
{
let attrs = attrs
.iter()
.filter(|attr| attr.path.is_ident("from_request"))
.map(|attr| attr.parse_args_with(Punctuated::<T, Token![,]>::parse_terminated))
.collect::<syn::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Punctuated<T, Token![,]>>();
Ok(attrs)
}
fn double_attr_error<T>(ident: &str, spanned: T) -> syn::Error
where
T: ToTokens,
{
syn::Error::new_spanned(spanned, format!("`{}` specified more than once", ident))
}
enum ContainerAttr {
Via {
via: kw::via,
path: syn::Path,
},
Rejection {
rejection: kw::rejection,
path: syn::Path,
},
}
impl Parse for ContainerAttr {
impl Parse for FromRequestContainerAttrs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lh = input.lookahead1();
if lh.peek(kw::via) {
let via = input.parse::<kw::via>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else if lh.peek(kw::rejection) {
let rejection = input.parse::<kw::rejection>()?;
let content;
syn::parenthesized!(content in input);
content
.parse()
.map(|path| Self::Rejection { rejection, path })
} else {
Err(lh.error())
let mut via = None;
let mut rejection = None;
while !input.is_empty() {
let lh = input.lookahead1();
if lh.peek(kw::via) {
parse_parenthesized_attribute(input, &mut via)?;
} else if lh.peek(kw::rejection) {
parse_parenthesized_attribute(input, &mut rejection)?;
} else {
return Err(lh.error());
}
let _ = input.parse::<Token![,]>();
}
Ok(Self { via, rejection })
}
}
enum FieldAttr {
Via { via: kw::via, path: syn::Path },
impl Combine for FromRequestContainerAttrs {
fn combine(mut self, other: Self) -> syn::Result<Self> {
let Self { via, rejection } = other;
combine_attribute(&mut self.via, via)?;
combine_attribute(&mut self.rejection, rejection)?;
Ok(self)
}
}
impl Parse for FieldAttr {
#[derive(Default)]
pub(super) struct FromRequestFieldAttrs {
pub(super) via: Option<(kw::via, syn::Path)>,
}
impl Parse for FromRequestFieldAttrs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lh = input.lookahead1();
if lh.peek(kw::via) {
let via = input.parse::<kw::via>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else {
Err(lh.error())
let mut via = None;
while !input.is_empty() {
let lh = input.lookahead1();
if lh.peek(kw::via) {
parse_parenthesized_attribute(input, &mut via)?;
} else {
return Err(lh.error());
}
let _ = input.parse::<Token![,]>();
}
Ok(Self { via })
}
}
impl Combine for FromRequestFieldAttrs {
fn combine(mut self, other: Self) -> syn::Result<Self> {
let Self { via } = other;
combine_attribute(&mut self.via, via)?;
Ok(self)
}
}

View file

@ -47,6 +47,7 @@ use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::parse::Parse;
mod attr_parsing;
mod debug_handler;
mod from_request;
mod typed_path;

View file

@ -1,8 +1,8 @@
error: cannot use `rejection` without `via`
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:18:26
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:18:16
|
18 | #[from_request(rejection(MyRejection))]
| ^^^^^^^^^^^
| ^^^^^^^^^
error[E0277]: the trait bound `fn(MyExtractor) -> impl Future<Output = ()> {handler}: Handler<_, _, _>` is not satisfied
--> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50