From 6bd726f4a92958eea0173ca2c951f7e76c8cb87c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 6 May 2022 16:53:12 +0200 Subject: [PATCH] Refactor `#[derive(FromRequest)]` to make illegal state unrepresentable (#1004) --- axum-macros/src/from_request.rs | 20 ++++------ axum-macros/src/from_request/attr.rs | 58 +++++++++++++++------------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 158719d3..901b2688 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -30,18 +30,14 @@ pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result { return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR)); } - let FromRequestContainerAttr { - via, - rejection_derive, - } = parse_container_attrs(&attrs)?; - - if let Some((_, path)) = via { - impl_by_extracting_all_at_once(ident, fields, path) - } else { - let rejection_derive_opt_outs = rejection_derive - .map(|(_, opt_outs)| opt_outs) - .unwrap_or_default(); - impl_by_extracting_each_field(ident, fields, vis, rejection_derive_opt_outs) + match parse_container_attrs(&attrs)? { + FromRequestContainerAttr::Via(path) => impl_by_extracting_all_at_once(ident, fields, path), + FromRequestContainerAttr::RejectionDerive(opt_outs) => { + impl_by_extracting_each_field(ident, fields, vis, opt_outs) + } + FromRequestContainerAttr::None => { + impl_by_extracting_each_field(ident, fields, vis, RejectionDeriveOptOuts::default()) + } } } diff --git a/axum-macros/src/from_request/attr.rs b/axum-macros/src/from_request/attr.rs index 35b63d79..25c989a0 100644 --- a/axum-macros/src/from_request/attr.rs +++ b/axum-macros/src/from_request/attr.rs @@ -10,10 +10,10 @@ pub(crate) struct FromRequestFieldAttr { pub(crate) via: Option<(kw::via, syn::Path)>, } -#[derive(Default)] -pub(crate) struct FromRequestContainerAttr { - pub(crate) via: Option<(kw::via, syn::Path)>, - pub(crate) rejection_derive: Option<(kw::rejection_derive, RejectionDeriveOptOuts)>, +pub(crate) enum FromRequestContainerAttr { + Via(syn::Path), + RejectionDerive(RejectionDeriveOptOuts), + None, } pub(crate) mod kw { @@ -47,47 +47,53 @@ pub(crate) fn parse_field_attrs(attrs: &[syn::Attribute]) -> syn::Result syn::Result { - let attrs = parse_attrs(attrs)?; + let attrs = parse_attrs::(attrs)?; - let mut out = FromRequestContainerAttr::default(); + let mut out_via = None; + let mut out_rejection_derive = None; - for from_request_attr in attrs { + // we track the index of the attribute to know which comes last + // used to give more accurate error messages + for (idx, from_request_attr) in attrs.into_iter().enumerate() { match from_request_attr { ContainerAttr::Via { via, path } => { - if out.rejection_derive.is_some() { - return Err(syn::Error::new_spanned( - via, - "cannot use both `rejection_derive` and `via`", - )); - } - - if out.via.is_some() { + if out_via.is_some() { return Err(double_attr_error("via", via)); } else { - out.via = Some((via, path)); + out_via = Some((idx, via, path)); } } ContainerAttr::RejectionDerive { rejection_derive, opt_outs, } => { - if out.via.is_some() { - return Err(syn::Error::new_spanned( - rejection_derive, - "cannot use both `via` and `rejection_derive`", - )); - } - - if out.rejection_derive.is_some() { + if out_rejection_derive.is_some() { return Err(double_attr_error("rejection_derive", rejection_derive)); } else { - out.rejection_derive = Some((rejection_derive, opt_outs)); + out_rejection_derive = Some((idx, rejection_derive, opt_outs)); } } } } - Ok(out) + match (out_via, out_rejection_derive) { + (Some((via_idx, via, _)), Some((rejection_derive_idx, rejection_derive, _))) => { + if via_idx > rejection_derive_idx { + Err(syn::Error::new_spanned( + via, + "cannot use both `rejection_derive` and `via`", + )) + } else { + Err(syn::Error::new_spanned( + rejection_derive, + "cannot use both `via` and `rejection_derive`", + )) + } + } + (Some((_, _, path)), None) => Ok(FromRequestContainerAttr::Via(path)), + (None, Some((_, _, opt_outs))) => Ok(FromRequestContainerAttr::RejectionDerive(opt_outs)), + (None, None) => Ok(FromRequestContainerAttr::None), + } } pub(crate) fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result>