diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index e3b4a885..66abd6ff 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,8 +1,10 @@ +use std::collections::HashSet; + use crate::{ attr_parsing::{parse_assignment_attribute, second}, with_position::{Position, WithPosition}, }; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; @@ -22,20 +24,38 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { // If the function is generic, we can't reliably check its inputs or whether the future it // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { + let mut err = None; + if state_ty.is_none() { - state_ty = state_type_from_args(&item_fn); + let state_types_from_args = state_types_from_args(&item_fn); + + #[allow(clippy::comparison_chain)] + if state_types_from_args.len() == 1 { + state_ty = state_types_from_args.into_iter().next(); + } else if state_types_from_args.len() > 1 { + err = Some( + syn::Error::new( + Span::call_site(), + "can't infer state type, please add set it explicitly, as in \ + `#[debug_handler(state = MyStateType)]`", + ) + .into_compile_error(), + ); + } } - let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); + err.unwrap_or_else(|| { + let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); - let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); - let check_future_send = check_future_send(&item_fn); + let check_inputs_impls_from_request = + check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); + let check_future_send = check_future_send(&item_fn); - quote! { - #check_inputs_impls_from_request - #check_future_send - } + quote! { + #check_inputs_impls_from_request + #check_future_send + } + }) } else { syn::Error::new_spanned( &item_fn.sig.generics, @@ -433,7 +453,7 @@ fn self_receiver(item_fn: &ItemFn) -> Option { /// This will extract `AppState`. /// /// Returns `None` if there are no `State` args or multiple of different types. -fn state_type_from_args(item_fn: &ItemFn) -> Option { +fn state_types_from_args(item_fn: &ItemFn) -> HashSet { let types = item_fn .sig .inputs @@ -443,7 +463,7 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option { FnArg::Typed(pat_type) => Some(pat_type), }) .map(|pat_type| &*pat_type.ty); - crate::infer_state_type(types) + crate::infer_state_types(types).collect() } #[test] diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 6803a2c2..7f6e76c9 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -47,6 +47,7 @@ impl fmt::Display for Trait { enum State { Custom(syn::Type), Default(syn::Type), + CannotInfer, } impl State { @@ -58,6 +59,7 @@ impl State { match self { State::Default(inner) => Some(inner.clone()), State::Custom(_) => None, + State::CannotInfer => Some(parse_quote!(S)), } .into_iter() } @@ -70,6 +72,7 @@ impl State { match self { State::Default(inner) => iter::once(inner.clone()), State::Custom(inner) => iter::once(inner.clone()), + State::CannotInfer => iter::once(parse_quote!(S)), } } @@ -79,6 +82,9 @@ impl State { State::Default(inner) => quote! { #inner: ::std::marker::Send + ::std::marker::Sync, }, + State::CannotInfer => quote! { + S: ::std::marker::Send + ::std::marker::Sync, + }, } } } @@ -88,6 +94,7 @@ impl ToTokens for State { match self { State::Custom(inner) => inner.to_tokens(tokens), State::Default(inner) => inner.to_tokens(tokens), + State::CannotInfer => quote! { S }.to_tokens(tokens), } } } @@ -115,30 +122,60 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { let state = match state { Some((_, state)) => State::Custom(state), - None => infer_state_type_from_field_types(&fields) - .map(State::Custom) - .or_else(|| infer_state_type_from_field_attributes(&fields).map(State::Custom)) - .or_else(|| { - let via = via.as_ref().map(|(_, via)| via)?; - state_from_via(&ident, via).map(State::Custom) - }) - .unwrap_or_else(|| State::Default(syn::parse_quote!(S))), + None => { + let mut inferred_state_types: HashSet<_> = + infer_state_type_from_field_types(&fields) + .chain(infer_state_type_from_field_attributes(&fields)) + .collect(); + + if let Some((_, via)) = &via { + inferred_state_types.extend(state_from_via(&ident, via)); + } + + match inferred_state_types.len() { + 0 => State::Default(syn::parse_quote!(S)), + 1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()), + _ => State::CannotInfer, + } + } }; - match (via.map(second), rejection.map(second)) { + let trait_impl = match (via.map(second), rejection.map(second)) { (Some(via), rejection) => impl_struct_by_extracting_all_at_once( ident, fields, via, rejection, generic_ident, - state, + &state, tr, - ), + )?, (None, rejection) => { error_on_generic_ident(generic_ident, tr)?; - impl_struct_by_extracting_each_field(ident, fields, rejection, state, tr) + impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)? } + }; + + if let State::CannotInfer = state { + let attr_name = match tr { + Trait::FromRequest => "from_request", + Trait::FromRequestParts => "from_request_parts", + }; + let compile_error = syn::Error::new( + Span::call_site(), + format_args!( + "can't infer state type, please add \ + `#[{attr_name}(state = MyStateType)]` attribute", + ), + ) + .into_compile_error(); + + Ok(quote! { + #trait_impl + #compile_error + }) + } else { + Ok(trait_impl) } } syn::Item::Enum(item) => { @@ -308,10 +345,22 @@ fn impl_struct_by_extracting_each_field( ident: syn::Ident, fields: syn::Fields, rejection: Option, - state: State, + state: &State, tr: Trait, ) -> syn::Result { - let extract_fields = extract_fields(&fields, &rejection, tr)?; + let trait_fn_body = match state { + State::CannotInfer => quote! { + ::std::unimplemented!() + }, + _ => { + let extract_fields = extract_fields(&fields, &rejection, tr)?; + quote! { + ::std::result::Result::Ok(Self { + #(#extract_fields)* + }) + } + } + }; let rejection_ident = if let Some(rejection) = rejection { quote!(#rejection) @@ -350,9 +399,7 @@ fn impl_struct_by_extracting_each_field( mut req: ::axum::http::Request, state: &#state, ) -> ::std::result::Result { - ::std::result::Result::Ok(Self { - #(#extract_fields)* - }) + #trait_fn_body } } }, @@ -369,9 +416,7 @@ fn impl_struct_by_extracting_each_field( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { - ::std::result::Result::Ok(Self { - #(#extract_fields)* - }) + #trait_fn_body } } }, @@ -661,7 +706,7 @@ fn impl_struct_by_extracting_all_at_once( via_path: syn::Path, rejection: Option, generic_ident: Option, - state: State, + state: &State, tr: Trait, ) -> syn::Result { let fields = match fields { @@ -952,15 +997,15 @@ fn impl_enum_by_extracting_all_at_once( /// ``` /// /// We can infer the state type to be `AppState` because it appears inside a `State` -fn infer_state_type_from_field_types(fields: &Fields) -> Option { +fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator + '_ { match fields { - Fields::Named(fields_named) => { - crate::infer_state_type(fields_named.named.iter().map(|field| &field.ty)) - } - Fields::Unnamed(fields_unnamed) => { - crate::infer_state_type(fields_unnamed.unnamed.iter().map(|field| &field.ty)) - } - Fields::Unit => None, + Fields::Named(fields_named) => Box::new(crate::infer_state_types( + fields_named.named.iter().map(|field| &field.ty), + )) as Box>, + Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types( + fields_unnamed.unnamed.iter().map(|field| &field.ty), + )), + Fields::Unit => Box::new(iter::empty()), } } @@ -975,43 +1020,29 @@ fn infer_state_type_from_field_types(fields: &Fields) -> Option { /// /// We can infer the state type to be `AppState` because it has `via(State)` and thus can be /// extracted with `State` -fn infer_state_type_from_field_attributes(fields: &Fields) -> Option { - let state_inputs = match fields { +fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator + '_ { + match fields { Fields::Named(fields_named) => { - fields_named - .named - .iter() - .filter_map(|field| { - // TODO(david): its a little wasteful to parse the attributes again here - // ideally we should parse things once and pass the data down - let FromRequestFieldAttrs { via } = - parse_attrs("from_request", &field.attrs).ok()?; - let (_, via_path) = via?; - path_ident_is_state(&via_path).then(|| &field.ty) - }) - .collect::>() + Box::new(fields_named.named.iter().filter_map(|field| { + // TODO(david): its a little wasteful to parse the attributes again here + // ideally we should parse things once and pass the data down + let FromRequestFieldAttrs { via } = + parse_attrs("from_request", &field.attrs).ok()?; + let (_, via_path) = via?; + path_ident_is_state(&via_path).then(|| field.ty.clone()) + })) as Box> } Fields::Unnamed(fields_unnamed) => { - fields_unnamed - .unnamed - .iter() - .filter_map(|field| { - // TODO(david): its a little wasteful to parse the attributes again here - // ideally we should parse things once and pass the data down - let FromRequestFieldAttrs { via } = - parse_attrs("from_request", &field.attrs).ok()?; - let (_, via_path) = via?; - path_ident_is_state(&via_path).then(|| &field.ty) - }) - .collect::>() + Box::new(fields_unnamed.unnamed.iter().filter_map(|field| { + // TODO(david): its a little wasteful to parse the attributes again here + // ideally we should parse things once and pass the data down + let FromRequestFieldAttrs { via } = + parse_attrs("from_request", &field.attrs).ok()?; + let (_, via_path) = via?; + path_ident_is_state(&via_path).then(|| field.ty.clone()) + })) } - Fields::Unit => return None, - }; - - if state_inputs.len() == 1 { - state_inputs.iter().next().map(|&ty| ty.clone()) - } else { - None + Fields::Unit => Box::new(iter::empty()), } } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 0529315c..0644f057 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -43,8 +43,6 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] -use std::collections::HashSet; - use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse::Parse, Type}; @@ -615,11 +613,11 @@ where } } -fn infer_state_type<'a, I>(types: I) -> Option +fn infer_state_types<'a, I>(types: I) -> impl Iterator + 'a where - I: Iterator, + I: Iterator + 'a, { - let state_inputs = types + types .filter_map(|ty| { if let Type::Path(path) = ty { Some(&path.path) @@ -650,13 +648,7 @@ where None } }) - .collect::>(); - - if state_inputs.len() == 1 { - state_inputs.iter().next().map(|&ty| ty.clone()) - } else { - None - } + .cloned() } #[cfg(test)] diff --git a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr index 7b9bfdb3..5ed0ddef 100644 --- a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr +++ b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr @@ -1,23 +1,7 @@ -error[E0277]: the trait bound `AppState: FromRef` is not satisfied - --> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18 +error: can't infer state type, please add `#[from_request(state = MyStateType)]` attribute + --> tests/from_request/fail/state_infer_multiple_different_types.rs:4:10 | -6 | inner_state: State, - | ^^^^^ the trait `FromRef` is not implemented for `AppState` +4 | #[derive(FromRequest)] + | ^^^^^^^^^^^ | - = note: required because of the requirements on the impl of `FromRequestParts` for `State` -help: consider extending the `where` clause, but there might be an alternative better way to express this requirement - | -4 | #[derive(FromRequest, AppState: FromRef)] - | ++++++++++++++++++++++ - -error[E0277]: the trait bound `OtherState: FromRef` is not satisfied - --> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18 - | -7 | other_state: State, - | ^^^^^ the trait `FromRef` is not implemented for `OtherState` - | - = note: required because of the requirements on the impl of `FromRequestParts` for `State` -help: consider extending the `where` clause, but there might be an alternative better way to express this requirement - | -4 | #[derive(FromRequest, OtherState: FromRef)] - | ++++++++++++++++++++++++ + = note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info)