use std::collections::HashSet; use crate::{ attr_parsing::{parse_assignment_attribute, second}, with_position::{Position, WithPosition}, }; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let Attrs { body_ty, state_ty } = attr; let body_ty = body_ty .map(second) .unwrap_or_else(|| parse_quote!(axum::body::Body)); let mut state_ty = state_ty.map(second); let check_extractor_count = check_extractor_count(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); // If the function is generic, we can't reliably check its inputs or whether the future it // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { let mut err = None; if state_ty.is_none() { let state_types_from_args = state_types_from_args(&item_fn); #[allow(clippy::comparison_chain)] if state_types_from_args.len() == 1 { state_ty = state_types_from_args.into_iter().next(); } else if state_types_from_args.len() > 1 { err = Some( syn::Error::new( Span::call_site(), "can't infer state type, please add set it explicitly, as in \ `#[debug_handler(state = MyStateType)]`", ) .into_compile_error(), ); } } err.unwrap_or_else(|| { let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); let check_inputs_impls_from_request = check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); let check_future_send = check_future_send(&item_fn); quote! { #check_inputs_impls_from_request #check_future_send } }) } else { syn::Error::new_spanned( &item_fn.sig.generics, "`#[axum_macros::debug_handler]` doesn't support generic functions", ) .into_compile_error() }; quote! { #item_fn #check_extractor_count #check_path_extractor #check_output_impls_into_response #check_inputs_and_future_send } } mod kw { syn::custom_keyword!(body); syn::custom_keyword!(state); } pub(crate) struct Attrs { body_ty: Option<(kw::body, Type)>, state_ty: Option<(kw::state, Type)>, } impl Parse for Attrs { fn parse(input: syn::parse::ParseStream) -> syn::Result { let mut body_ty = None; let mut state_ty = None; while !input.is_empty() { let lh = input.lookahead1(); if lh.peek(kw::body) { parse_assignment_attribute(input, &mut body_ty)?; } else if lh.peek(kw::state) { parse_assignment_attribute(input, &mut state_ty)?; } else { return Err(lh.error()); } let _ = input.parse::(); } Ok(Self { body_ty, state_ty }) } } fn check_extractor_count(item_fn: &ItemFn) -> Option { let max_extractors = 16; if item_fn.sig.inputs.len() <= max_extractors { None } else { let error_message = format!( "Handlers cannot take more than {} arguments. \ Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", max_extractors, ); let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); Some(error) } } fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { item_fn .sig .inputs .iter() .enumerate() .filter_map(|(idx, fn_arg)| match fn_arg { FnArg::Receiver(_) => None, FnArg::Typed(pat_type) => { if let Type::Path(type_path) = &*pat_type.ty { type_path .path .segments .last() .map(|segment| (idx, fn_arg, &segment.ident)) } else { None } } }) } fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { let path_extractors = extractor_idents(item_fn) .filter(|(_, _, ident)| *ident == "Path") .collect::>(); if path_extractors.len() > 1 { path_extractors .into_iter() .map(|(_, arg, _)| { syn::Error::new_spanned( arg, "Multiple parameters must be extracted with a tuple \ `Path<(_, _)>` or a struct `Path`, not by applying \ multiple `Path<_>` extractors", ) .to_compile_error() }) .collect() } else { quote! {} } } fn is_self_pat_type(typed: &syn::PatType) -> bool { let ident = if let syn::Pat::Ident(ident) = &*typed.pat { &ident.ident } else { return false; }; ident == "self" } fn check_inputs_impls_from_request( item_fn: &ItemFn, body_ty: &Type, state_ty: Type, ) -> TokenStream { let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); WithPosition::new(item_fn.sig.inputs.iter()) .enumerate() .map(|(idx, arg)| { let must_impl_from_request_parts = match &arg { Position::First(_) | Position::Middle(_) => true, Position::Last(_) | Position::Only(_) => false, }; let arg = arg.into_inner(); let (span, ty) = match arg { FnArg::Receiver(receiver) => { if receiver.reference.is_some() { return syn::Error::new_spanned( receiver, "Handlers must only take owned values", ) .into_compile_error(); } let span = receiver.span(); (span, syn::parse_quote!(Self)) } FnArg::Typed(typed) => { let ty = &typed.ty; let span = ty.span(); if is_self_pat_type(typed) { (span, syn::parse_quote!(Self)) } else { (span, ty.clone()) } } }; let check_fn = format_ident!( "__axum_macros_check_{}_{}_from_request_check", item_fn.sig.ident, idx, span = span, ); let call_check_fn = format_ident!( "__axum_macros_check_{}_{}_from_request_call_check", item_fn.sig.ident, idx, span = span, ); let call_check_fn_body = if takes_self { quote_spanned! {span=> Self::#check_fn(); } } else { quote_spanned! {span=> #check_fn(); } }; let check_fn_generics = if must_impl_from_request_parts { quote! {} } else { quote! { } }; let from_request_bound = if must_impl_from_request_parts { quote! { #ty: ::axum::extract::FromRequestParts<#state_ty> + Send } } else { quote! { #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send } }; quote_spanned! {span=> #[allow(warnings)] fn #check_fn #check_fn_generics() where #from_request_bound, {} // we have to call the function to actually trigger a compile error // since the function is generic, just defining it is not enough #[allow(warnings)] fn #call_check_fn() { #call_check_fn_body } } }) .collect::() } fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { let ty = match &item_fn.sig.output { syn::ReturnType::Default => return quote! {}, syn::ReturnType::Type(_, ty) => ty, }; let span = ty.span(); let declare_inputs = item_fn .sig .inputs .iter() .filter_map(|arg| match arg { FnArg::Receiver(_) => None, FnArg::Typed(pat_ty) => { let pat = &pat_ty.pat; let ty = &pat_ty.ty; Some(quote! { let #pat: #ty = panic!(); }) } }) .collect::(); let block = &item_fn.block; let make_value_name = format_ident!( "__axum_macros_check_{}_into_response_make_value", item_fn.sig.ident ); let make = if item_fn.sig.asyncness.is_some() { quote_spanned! {span=> #[allow(warnings)] async fn #make_value_name() -> #ty { #declare_inputs #block } } } else { quote_spanned! {span=> #[allow(warnings)] fn #make_value_name() -> #ty { #declare_inputs #block } } }; let name = format_ident!("__axum_macros_check_{}_into_response", item_fn.sig.ident); if let Some(receiver) = self_receiver(item_fn) { quote_spanned! {span=> #make #[allow(warnings)] async fn #name() { let value = #receiver #make_value_name().await; fn check(_: T) where T: ::axum::response::IntoResponse {} check(value); } } } else { quote_spanned! {span=> #[allow(warnings)] async fn #name() { #make let value = #make_value_name().await; fn check(_: T) where T: ::axum::response::IntoResponse {} check(value); } } } } fn check_future_send(item_fn: &ItemFn) -> TokenStream { if item_fn.sig.asyncness.is_none() { match &item_fn.sig.output { syn::ReturnType::Default => { return syn::Error::new_spanned( &item_fn.sig.fn_token, "Handlers must be `async fn`s", ) .into_compile_error(); } syn::ReturnType::Type(_, ty) => ty, }; } let span = item_fn.span(); let handler_name = &item_fn.sig.ident; let args = item_fn.sig.inputs.iter().map(|_| { quote_spanned! {span=> panic!() } }); let name = format_ident!("__axum_macros_check_{}_future", item_fn.sig.ident); if let Some(receiver) = self_receiver(item_fn) { quote_spanned! {span=> #[allow(warnings)] fn #name() { let future = #receiver #handler_name(#(#args),*); fn check(_: T) where T: ::std::future::Future + Send {} check(future); } } } else { quote_spanned! {span=> #[allow(warnings)] fn #name() { #item_fn let future = #handler_name(#(#args),*); fn check(_: T) where T: ::std::future::Future + Send {} check(future); } } } } fn self_receiver(item_fn: &ItemFn) -> Option { let takes_self = item_fn.sig.inputs.iter().any(|arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); if takes_self { return Some(quote! { Self:: }); } if let syn::ReturnType::Type(_, ty) = &item_fn.sig.output { if let syn::Type::Path(path) = &**ty { let segments = &path.path.segments; if segments.len() == 1 { if let Some(last) = segments.last() { match &last.arguments { syn::PathArguments::None if last.ident == "Self" => { return Some(quote! { Self:: }); } _ => {} } } } } } None } /// Given a signature like /// /// ```skip /// #[debug_handler] /// async fn handler( /// _: axum::extract::State, /// _: State, /// ) {} /// ``` /// /// This will extract `AppState`. /// /// Returns `None` if there are no `State` args or multiple of different types. fn state_types_from_args(item_fn: &ItemFn) -> HashSet { let types = item_fn .sig .inputs .iter() .filter_map(|input| match input { FnArg::Receiver(_) => None, FnArg::Typed(pat_type) => Some(pat_type), }) .map(|pat_type| &*pat_type.ty); crate::infer_state_types(types).collect() } #[test] fn ui() { crate::run_ui_tests("debug_handler"); }