Support opt-out of extra derived traits for rejections for #[derive(FromRequest)] (#729)

* Handle structs without fields

* Support opt-out of derived rejection traits

* Handle duplicate opt outs

* Improve error if opting out of `Display` or `Debug` but not `Error`

* document `rejection_derive`

* Handle using both `via` and `rejection_derive`

* don't derive debug for `RejectionDeriveOptOuts`

* Update axum-macros/src/from_request.rs

Co-authored-by: Jonas Platte <jplatte@users.noreply.github.com>

Co-authored-by: Jonas Platte <jplatte@users.noreply.github.com>
This commit is contained in:
David Pedersen 2022-01-28 10:54:38 +01:00 committed by GitHub
parent f6fc5ed80c
commit 911c4a788e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 504 additions and 91 deletions

View file

@ -1,12 +1,13 @@
use self::attr::{
parse_container_attrs, parse_field_attrs, FromRequestContainerAttr, FromRequestFieldAttr,
RejectionDeriveOptOuts,
};
use heck::ToUpperCamelCase;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Token,
};
use syn::{punctuated::Punctuated, spanned::Spanned, Token};
mod attr;
const GENERICS_ERROR: &str = "`#[derive(FromRequest)] doesn't support generics";
@ -29,12 +30,18 @@ pub(crate) fn expand(item: syn::ItemStruct) -> syn::Result<TokenStream> {
return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR));
}
let FromRequestAttrs { via } = parse_attrs(&attrs)?;
let FromRequestContainerAttr {
via,
rejection_derive,
} = parse_container_attrs(&attrs)?;
if let Some((_, path)) = via {
impl_by_extracting_all_at_once(ident, fields, path)
} else {
impl_by_extracting_each_field(ident, fields, vis)
let rejection_derive_opt_outs = rejection_derive
.map(|(_, opt_outs)| opt_outs)
.unwrap_or_default();
impl_by_extracting_each_field(ident, fields, vis, rejection_derive_opt_outs)
}
}
@ -42,15 +49,17 @@ fn impl_by_extracting_each_field(
ident: syn::Ident,
fields: syn::Fields,
vis: syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
) -> syn::Result<TokenStream> {
let extract_fields = extract_fields(&fields)?;
let (rejection_ident, rejection) = if let syn::Fields::Unit = &fields {
(syn::parse_quote!(::std::convert::Infallible), quote! {})
let (rejection_ident, rejection) = if has_no_fields(&fields) {
(syn::parse_quote!(::std::convert::Infallible), None)
} else {
let rejection_ident = rejection_ident(&ident);
let rejection = extract_each_field_rejection(&ident, &fields, &vis)?;
(rejection_ident, rejection)
let rejection =
extract_each_field_rejection(&ident, &fields, &vis, rejection_derive_opt_outs)?;
(rejection_ident, Some(rejection))
};
Ok(quote! {
@ -77,6 +86,14 @@ fn impl_by_extracting_each_field(
})
}
fn has_no_fields(fields: &syn::Fields) -> bool {
match fields {
syn::Fields::Named(fields) => fields.named.is_empty(),
syn::Fields::Unnamed(fields) => fields.unnamed.is_empty(),
syn::Fields::Unit => true,
}
}
fn rejection_ident(ident: &syn::Ident) -> syn::Type {
let ident = format_ident!("{}Rejection", ident);
syn::parse_quote!(#ident)
@ -87,7 +104,7 @@ fn extract_fields(fields: &syn::Fields) -> syn::Result<Vec<TokenStream>> {
.iter()
.enumerate()
.map(|(index, field)| {
let FromRequestAttrs { via } = parse_attrs(&field.attrs)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = if let Some(ident) = &field.ident {
quote! { #ident }
@ -211,13 +228,14 @@ fn extract_each_field_rejection(
ident: &syn::Ident,
fields: &syn::Fields,
vis: &syn::Visibility,
rejection_derive_opt_outs: RejectionDeriveOptOuts,
) -> syn::Result<TokenStream> {
let rejection_ident = rejection_ident(ident);
let variants = fields
.iter()
.map(|field| {
let FromRequestAttrs { via } = parse_attrs(&field.attrs)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let field_ty = &field.ty;
let ty_span = field_ty.span();
@ -270,7 +288,7 @@ fn extract_each_field_rejection(
}
};
let impl_display = {
let impl_display = if rejection_derive_opt_outs.derive_display() {
let arms = fields
.iter()
.map(|field| {
@ -281,7 +299,7 @@ fn extract_each_field_rejection(
})
.collect::<syn::Result<Vec<_>>>()?;
quote! {
Some(quote! {
#[automatically_derived]
impl ::std::fmt::Display for #rejection_ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
@ -290,10 +308,12 @@ fn extract_each_field_rejection(
}
}
}
}
})
} else {
None
};
let impl_error = {
let impl_error = if rejection_derive_opt_outs.derive_error() {
let arms = fields
.iter()
.map(|field| {
@ -304,7 +324,7 @@ fn extract_each_field_rejection(
})
.collect::<syn::Result<Vec<_>>>()?;
quote! {
Some(quote! {
#[automatically_derived]
impl ::std::error::Error for #rejection_ident {
fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
@ -313,11 +333,17 @@ fn extract_each_field_rejection(
}
}
}
}
})
} else {
None
};
let impl_debug = rejection_derive_opt_outs.derive_debug().then(|| {
quote! { #[derive(Debug)] }
});
Ok(quote! {
#[derive(Debug)]
#impl_debug
#vis enum #rejection_ident {
#(#variants)*
}
@ -381,7 +407,7 @@ fn rejection_variant_name(field: &syn::Field) -> syn::Result<syn::Ident> {
let mut out = String::new();
rejection_variant_name_for_type(&mut out, &field.ty)?;
let FromRequestAttrs { via } = parse_attrs(&field.attrs)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
if let Some((_, path)) = via {
let via_ident = &path.segments.last().unwrap().ident;
Ok(format_ident!("{}{}", via_ident, out))
@ -403,7 +429,7 @@ fn impl_by_extracting_all_at_once(
};
for field in fields {
let FromRequestAttrs { via } = parse_attrs(&field.attrs)?;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
if let Some((via, _)) = via {
return Err(syn::Error::new_spanned(
via,
@ -437,72 +463,6 @@ fn impl_by_extracting_all_at_once(
})
}
#[derive(Default)]
struct FromRequestAttrs {
via: Option<(kw::via, syn::Path)>,
}
mod kw {
syn::custom_keyword!(via);
}
fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result<FromRequestAttrs> {
enum Attr {
FromRequest(Punctuated<FromRequestAttr, Token![,]>),
}
enum FromRequestAttr {
Via { via: kw::via, path: syn::Path },
}
impl Parse for FromRequestAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lh = input.lookahead1();
if lh.peek(kw::via) {
let via = input.parse::<kw::via>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else {
Err(lh.error())
}
}
}
let attrs = attrs
.iter()
.filter(|attr| attr.path.is_ident("from_request"))
.map(|attr| {
attr.parse_args_with(Punctuated::parse_terminated)
.map(Attr::FromRequest)
})
.collect::<syn::Result<Vec<_>>>()?;
let mut out = FromRequestAttrs::default();
for attr in attrs {
match attr {
Attr::FromRequest(from_request_attrs) => {
for from_request_attr in from_request_attrs {
match from_request_attr {
FromRequestAttr::Via { via, path } => {
if out.via.is_some() {
return Err(syn::Error::new_spanned(
via,
"`via` specified more than once",
));
} else {
out.via = Some((via, path));
}
}
}
}
}
}
}
Ok(out)
}
#[test]
fn ui() {
#[rustversion::stable]

View file

@ -0,0 +1,243 @@
use quote::ToTokens;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Token,
};
#[derive(Default)]
pub(crate) struct FromRequestFieldAttr {
pub(crate) via: Option<(kw::via, syn::Path)>,
}
#[derive(Default)]
pub(crate) struct FromRequestContainerAttr {
pub(crate) via: Option<(kw::via, syn::Path)>,
pub(crate) rejection_derive: Option<(kw::rejection_derive, RejectionDeriveOptOuts)>,
}
pub(crate) mod kw {
syn::custom_keyword!(via);
syn::custom_keyword!(rejection_derive);
syn::custom_keyword!(Display);
syn::custom_keyword!(Debug);
syn::custom_keyword!(Error);
}
pub(crate) fn parse_field_attrs(attrs: &[syn::Attribute]) -> syn::Result<FromRequestFieldAttr> {
let attrs = parse_attrs(attrs)?;
let mut out = FromRequestFieldAttr::default();
for from_request_attr in attrs {
match from_request_attr {
FieldAttr::Via { via, path } => {
if out.via.is_some() {
return Err(double_attr_error("via", via));
} else {
out.via = Some((via, path));
}
}
}
}
Ok(out)
}
pub(crate) fn parse_container_attrs(
attrs: &[syn::Attribute],
) -> syn::Result<FromRequestContainerAttr> {
let attrs = parse_attrs(attrs)?;
let mut out = FromRequestContainerAttr::default();
for from_request_attr in attrs {
match from_request_attr {
ContainerAttr::Via { via, path } => {
if out.rejection_derive.is_some() {
return Err(syn::Error::new_spanned(
via,
"cannot use both `rejection_derive` and `via`",
));
}
if out.via.is_some() {
return Err(double_attr_error("via", via));
} else {
out.via = Some((via, path));
}
}
ContainerAttr::RejectionDerive {
rejection_derive,
opt_outs,
} => {
if out.via.is_some() {
return Err(syn::Error::new_spanned(
rejection_derive,
"cannot use both `via` and `rejection_derive`",
));
}
if out.rejection_derive.is_some() {
return Err(double_attr_error("rejection_derive", rejection_derive));
} else {
out.rejection_derive = Some((rejection_derive, opt_outs));
}
}
}
}
Ok(out)
}
pub(crate) fn parse_attrs<T>(attrs: &[syn::Attribute]) -> syn::Result<Punctuated<T, Token![,]>>
where
T: Parse,
{
let attrs = attrs
.iter()
.filter(|attr| attr.path.is_ident("from_request"))
.map(|attr| attr.parse_args_with(Punctuated::<T, Token![,]>::parse_terminated))
.collect::<syn::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Punctuated<T, Token![,]>>();
Ok(attrs)
}
fn double_attr_error<T>(ident: &str, spanned: T) -> syn::Error
where
T: ToTokens,
{
syn::Error::new_spanned(spanned, format!("`{}` specified more than once", ident))
}
enum ContainerAttr {
Via {
via: kw::via,
path: syn::Path,
},
RejectionDerive {
rejection_derive: kw::rejection_derive,
opt_outs: RejectionDeriveOptOuts,
},
}
impl Parse for ContainerAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lh = input.lookahead1();
if lh.peek(kw::via) {
let via = input.parse::<kw::via>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else if lh.peek(kw::rejection_derive) {
let rejection_derive = input.parse::<kw::rejection_derive>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|opt_outs| Self::RejectionDerive {
rejection_derive,
opt_outs,
})
} else {
Err(lh.error())
}
}
}
enum FieldAttr {
Via { via: kw::via, path: syn::Path },
}
impl Parse for FieldAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lh = input.lookahead1();
if lh.peek(kw::via) {
let via = input.parse::<kw::via>()?;
let content;
syn::parenthesized!(content in input);
content.parse().map(|path| Self::Via { via, path })
} else {
Err(lh.error())
}
}
}
#[derive(Default)]
pub(crate) struct RejectionDeriveOptOuts {
debug: Option<kw::Debug>,
display: Option<kw::Display>,
error: Option<kw::Error>,
}
impl RejectionDeriveOptOuts {
pub(crate) fn derive_debug(&self) -> bool {
self.debug.is_none()
}
pub(crate) fn derive_display(&self) -> bool {
self.display.is_none()
}
pub(crate) fn derive_error(&self) -> bool {
self.error.is_none()
}
}
impl Parse for RejectionDeriveOptOuts {
fn parse(input: ParseStream) -> syn::Result<Self> {
fn parse_opt_out<T>(out: &mut Option<T>, ident: &str, input: ParseStream) -> syn::Result<()>
where
T: Parse,
{
if out.is_some() {
Err(input.error(format!("`{}` opt out specified more than once", ident)))
} else {
*out = Some(input.parse()?);
Ok(())
}
}
let mut debug = None::<kw::Debug>;
let mut display = None::<kw::Display>;
let mut error = None::<kw::Error>;
while !input.is_empty() {
input.parse::<Token![!]>()?;
let lh = input.lookahead1();
if lh.peek(kw::Debug) {
parse_opt_out(&mut debug, "Debug", input)?;
} else if lh.peek(kw::Display) {
parse_opt_out(&mut display, "Display", input)?;
} else if lh.peek(kw::Error) {
parse_opt_out(&mut error, "Error", input)?;
} else {
return Err(lh.error());
}
input.parse::<Token![,]>().ok();
}
if error.is_none() {
match (debug, display) {
(Some(debug), Some(_)) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`"));
}
(Some(debug), None) => {
return Err(syn::Error::new_spanned(debug, "opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`"));
}
(None, Some(display)) => {
return Err(syn::Error::new_spanned(display, "opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`"));
}
(None, None) => {}
}
}
Ok(Self {
debug,
display,
error,
})
}
}

View file

@ -201,6 +201,41 @@ mod from_request;
/// means the inner rejection types must themselves implement `std::error::Error`. All extractors
/// in axum does this.
///
/// You can opt out of this using `#[from_request(rejection_derive(...))]`:
///
/// ```
/// use axum_macros::FromRequest;
/// use axum::{
/// extract::{FromRequest, RequestParts},
/// http::StatusCode,
/// headers::ContentType,
/// body::Bytes,
/// async_trait,
/// };
///
/// #[derive(FromRequest)]
/// #[from_request(rejection_derive(!Display, !Error))]
/// struct MyExtractor {
/// other: OtherExtractor,
/// }
///
/// struct OtherExtractor;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for OtherExtractor
/// where
/// B: Send + 'static,
/// {
/// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
/// }
/// ```
///
/// # The whole type at once
///
/// By using `#[from_request(via(...))]` on the container you can extract the whole type at once,

View file

@ -0,0 +1,9 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug, !Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -0,0 +1,5 @@
error: opt out of `Debug` and `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_and_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug, !Display))]
| ^^^^^

View file

@ -0,0 +1,9 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Debug))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -0,0 +1,5 @@
error: opt out of `Debug` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Debug, !Error))]`
--> tests/from_request/fail/derive_opt_out_debug_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Debug))]
| ^^^^^

View file

@ -0,0 +1,9 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -0,0 +1,5 @@
error: opt out of `Display` requires also opting out of `Error`. Use `#[from_request(rejection_derive(!Display, !Error))]`
--> tests/from_request/fail/derive_opt_out_display_without_error.rs:4:34
|
4 | #[from_request(rejection_derive(!Display))]
| ^^^^^^^

View file

@ -0,0 +1,9 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Error, !Error))]
struct Extractor {
body: String,
}
fn main() {}

View file

@ -0,0 +1,5 @@
error: `Error` opt out specified more than once
--> tests/from_request/fail/derive_opt_out_duplicate.rs:4:42
|
4 | #[from_request(rejection_derive(!Error, !Error))]
| ^^^^^

View file

@ -0,0 +1,10 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)]
#[from_request(rejection_derive(!Error), via(Extension))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -0,0 +1,13 @@
error: cannot use both `rejection_derive` and `via`
--> tests/from_request/fail/rejection_derive_and_via.rs:5:42
|
5 | #[from_request(rejection_derive(!Error), via(Extension))]
| ^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/rejection_derive_and_via.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -0,0 +1,7 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(foo)]
struct Extractor;
fn main() {}

View file

@ -0,0 +1,5 @@
error: expected `via` or `rejection_derive`
--> tests/from_request/fail/unknown_attr_container.rs:4:16
|
4 | #[from_request(foo)]
| ^^^

View file

@ -1,5 +1,5 @@
error: expected `via`
--> tests/from_request/fail/unknown_attr.rs:4:33
--> tests/from_request/fail/unknown_attr_field.rs:4:33
|
4 | struct Extractor(#[from_request(foo)] String);
| ^^^

View file

@ -0,0 +1,10 @@
use axum_macros::FromRequest;
use axum::extract::Extension;
#[derive(FromRequest, Clone)]
#[from_request(via(Extension), rejection_derive(!Error))]
struct Extractor {
config: String,
}
fn main() {}

View file

@ -0,0 +1,13 @@
error: cannot use both `via` and `rejection_derive`
--> tests/from_request/fail/via_and_rejection_derive.rs:5:32
|
5 | #[from_request(via(Extension), rejection_derive(!Error))]
| ^^^^^^^^^^^^^^^^
warning: unused import: `axum::extract::Extension`
--> tests/from_request/fail/via_and_rejection_derive.rs:2:5
|
2 | use axum::extract::Extension;
| ^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

View file

@ -0,0 +1,37 @@
use axum::{
async_trait,
extract::{FromRequest, RequestParts},
response::{IntoResponse, Response},
};
use axum_macros::FromRequest;
#[derive(FromRequest)]
#[from_request(rejection_derive(!Display, !Error))]
struct Extractor {
other: OtherExtractor,
}
struct OtherExtractor;
#[async_trait]
impl<B> FromRequest<B> for OtherExtractor
where
B: Send + 'static,
{
type Rejection = OtherExtractorRejection;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
#[derive(Debug)]
struct OtherExtractorRejection;
impl IntoResponse for OtherExtractorRejection {
fn into_response(self) -> Response {
unimplemented!()
}
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
struct Extractor {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
{
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequest;
#[derive(FromRequest)]
struct Extractor();
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
{
}
fn main() {}

View file

@ -5,7 +5,7 @@ struct Extractor;
fn assert_from_request()
where
Extractor: axum::extract::FromRequest<axum::body::Body>,
Extractor: axum::extract::FromRequest<axum::body::Body, Rejection = std::convert::Infallible>,
{
}