diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 36a8ca18..6d9fa42b 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) + +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 # 0.4.1 (13. January, 2024) diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 90e2ae92..3a37c17a 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, fmt}; use crate::{ attr_parsing::{parse_assignment_attribute, second}, @@ -8,13 +8,13 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type}; -pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { +pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenStream { let Attrs { state_ty } = attr; let mut state_ty = state_ty.map(second); - let check_extractor_count = check_extractor_count(&item_fn); - let check_path_extractor = check_path_extractor(&item_fn); + let check_extractor_count = check_extractor_count(&item_fn, kind); + let check_path_extractor = check_path_extractor(&item_fn, kind); let check_output_tuples = check_output_tuples(&item_fn); let check_output_impls_into_response = if check_output_tuples.is_empty() { check_output_impls_into_response(&item_fn) @@ -37,8 +37,10 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err = Some( syn::Error::new( Span::call_site(), - "can't infer state type, please add set it explicitly, as in \ - `#[debug_handler(state = MyStateType)]`", + format!( + "can't infer state type, please add set it explicitly, as in \ + `#[axum_macros::debug_{kind}(state = MyStateType)]`" + ), ) .into_compile_error(), ); @@ -48,16 +50,16 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err.unwrap_or_else(|| { let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); - let check_future_send = check_future_send(&item_fn); + let check_future_send = check_future_send(&item_fn, kind); - if let Some(check_input_order) = check_input_order(&item_fn) { + if let Some(check_input_order) = check_input_order(&item_fn, kind) { quote! { #check_input_order #check_future_send } } else { let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, state_ty); + check_inputs_impls_from_request(&item_fn, state_ty, kind); quote! { #check_inputs_impls_from_request @@ -68,17 +70,45 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { } else { syn::Error::new_spanned( &item_fn.sig.generics, - "`#[axum_macros::debug_handler]` doesn't support generic functions", + format!("`#[axum_macros::debug_{kind}]` doesn't support generic functions"), ) .into_compile_error() }; + let middleware_takes_next_as_last_arg = + matches!(kind, FunctionKind::Middleware).then(|| next_is_last_input(&item_fn)); + quote! { #item_fn #check_extractor_count #check_path_extractor #check_output_impls_into_response #check_inputs_and_future_send + #middleware_takes_next_as_last_arg + } +} + +#[derive(Clone, Copy)] +pub(crate) enum FunctionKind { + Handler, + Middleware, +} + +impl fmt::Display for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionKind::Handler => f.write_str("handler"), + FunctionKind::Middleware => f.write_str("middleware"), + } + } +} + +impl FunctionKind { + fn name_uppercase_plural(&self) -> &'static str { + match self { + FunctionKind::Handler => "Handlers", + FunctionKind::Middleware => "Middleware", + } } } @@ -110,25 +140,36 @@ impl Parse for Attrs { } } -fn check_extractor_count(item_fn: &ItemFn) -> Option { +fn check_extractor_count(item_fn: &ItemFn, kind: FunctionKind) -> Option { let max_extractors = 16; - if item_fn.sig.inputs.len() <= max_extractors { + let inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + if inputs <= max_extractors { None } else { let error_message = format!( - "Handlers cannot take more than {max_extractors} arguments. \ + "{} cannot take more than {max_extractors} arguments. \ Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", + kind.name_uppercase_plural(), ); let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); Some(error) } } -fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { +fn extractor_idents( + item_fn: &ItemFn, + kind: FunctionKind, +) -> impl Iterator { item_fn .sig .inputs .iter() + .filter(move |arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, fn_arg)| match fn_arg { FnArg::Receiver(_) => None, @@ -146,8 +187,8 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator TokenStream { - let path_extractors = extractor_idents(item_fn) +fn check_path_extractor(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { + let path_extractors = extractor_idents(item_fn, kind) .filter(|(_, _, ident)| *ident == "Path") .collect::>(); @@ -179,113 +220,122 @@ fn is_self_pat_type(typed: &syn::PatType) -> bool { ident == "self" } -fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStream { +fn check_inputs_impls_from_request( + item_fn: &ItemFn, + state_ty: Type, + kind: FunctionKind, +) -> TokenStream { let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); - WithPosition::new(&item_fn.sig.inputs) - .enumerate() - .map(|(idx, arg)| { - let must_impl_from_request_parts = match &arg { - Position::First(_) | Position::Middle(_) => true, - Position::Last(_) | Position::Only(_) => false, - }; + WithPosition::new( + item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)), + ) + .enumerate() + .map(|(idx, arg)| { + let must_impl_from_request_parts = match &arg { + Position::First(_) | Position::Middle(_) => true, + Position::Last(_) | Position::Only(_) => false, + }; - let arg = arg.into_inner(); + let arg = arg.into_inner(); - let (span, ty) = match arg { - FnArg::Receiver(receiver) => { - if receiver.reference.is_some() { - return syn::Error::new_spanned( - receiver, - "Handlers must only take owned values", - ) - .into_compile_error(); - } + let (span, ty) = match arg { + FnArg::Receiver(receiver) => { + if receiver.reference.is_some() { + return syn::Error::new_spanned( + receiver, + "Handlers must only take owned values", + ) + .into_compile_error(); + } - let span = receiver.span(); + let span = receiver.span(); + (span, syn::parse_quote!(Self)) + } + FnArg::Typed(typed) => { + let ty = &typed.ty; + let span = ty.span(); + + if is_self_pat_type(typed) { (span, syn::parse_quote!(Self)) - } - FnArg::Typed(typed) => { - let ty = &typed.ty; - let span = ty.span(); - - if is_self_pat_type(typed) { - (span, syn::parse_quote!(Self)) - } else { - (span, ty.clone()) - } - } - }; - - let consumes_request = request_consuming_type_name(&ty).is_some(); - - let check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_check", - item_fn.sig.ident, - idx, - span = span, - ); - - let call_check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_call_check", - item_fn.sig.ident, - idx, - span = span, - ); - - let call_check_fn_body = if takes_self { - quote_spanned! {span=> - Self::#check_fn(); - } - } else { - quote_spanned! {span=> - #check_fn(); - } - }; - - let check_fn_generics = if must_impl_from_request_parts || consumes_request { - quote! {} - } else { - quote! { } - }; - - let from_request_bound = if must_impl_from_request_parts { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequestParts<#state_ty> + Send - } - } else if consumes_request { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty> + Send - } - } else { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty, M> + Send - } - }; - - quote_spanned! {span=> - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #check_fn #check_fn_generics() - where - #from_request_bound, - {} - - // we have to call the function to actually trigger a compile error - // since the function is generic, just defining it is not enough - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #call_check_fn() { - #call_check_fn_body + } else { + (span, ty.clone()) } } - }) - .collect::() + }; + + let consumes_request = request_consuming_type_name(&ty).is_some(); + + let check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_check", + item_fn.sig.ident, + idx, + span = span, + ); + + let call_check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_call_check", + item_fn.sig.ident, + idx, + span = span, + ); + + let call_check_fn_body = if takes_self { + quote_spanned! {span=> + Self::#check_fn(); + } + } else { + quote_spanned! {span=> + #check_fn(); + } + }; + + let check_fn_generics = if must_impl_from_request_parts || consumes_request { + quote! {} + } else { + quote! { } + }; + + let from_request_bound = if must_impl_from_request_parts { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequestParts<#state_ty> + Send + } + } else if consumes_request { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty> + Send + } + } else { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty, M> + Send + } + }; + + quote_spanned! {span=> + #[allow(warnings)] + #[doc(hidden)] + fn #check_fn #check_fn_generics() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[doc(hidden)] + fn #call_check_fn() + { + #call_check_fn_body + } + } + }) + .collect::() } fn check_output_tuples(item_fn: &ItemFn) -> TokenStream { @@ -445,11 +495,19 @@ fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStr } } -fn check_input_order(item_fn: &ItemFn) -> Option { +fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option { + let number_of_inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + let types_that_consume_the_request = item_fn .sig .inputs .iter() + .filter(|arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, arg)| { let ty = match arg { @@ -469,7 +527,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option { // exactly one type that consumes the request if types_that_consume_the_request.len() == 1 { // and that is not the last - if types_that_consume_the_request[0].0 != item_fn.sig.inputs.len() - 1 { + if types_that_consume_the_request[0].0 != number_of_inputs - 1 { let (_idx, type_name, span) = &types_that_consume_the_request[0]; let error = format!( "`{type_name}` consumes the request body and thus must be \ @@ -653,13 +711,13 @@ fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { } } -fn check_future_send(item_fn: &ItemFn) -> TokenStream { +fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { if item_fn.sig.asyncness.is_none() { match &item_fn.sig.output { syn::ReturnType::Default => { return syn::Error::new_spanned( item_fn.sig.fn_token, - "Handlers must be `async fn`s", + format!("{} must be `async fn`s", kind.name_uppercase_plural()), ) .into_compile_error(); } @@ -763,7 +821,69 @@ fn state_types_from_args(item_fn: &ItemFn) -> HashSet { crate::infer_state_types(types).collect() } +fn next_is_last_input(item_fn: &ItemFn) -> TokenStream { + let next_args = item_fn + .sig + .inputs + .iter() + .enumerate() + .filter(|(_, arg)| !skip_next_arg(arg, FunctionKind::Middleware)) + .collect::>(); + + if next_args.is_empty() { + return quote! { + compile_error!( + "Middleware functions must take `axum::middleware::Next` as the last argument", + ); + }; + } + + if next_args.len() == 1 { + let (idx, arg) = &next_args[0]; + if *idx != item_fn.sig.inputs.len() - 1 { + return quote_spanned! {arg.span()=> + compile_error!("`axum::middleware::Next` must the last argument"); + }; + } + } + + if next_args.len() >= 2 { + return quote! { + compile_error!( + "Middleware functions can only take one argument of type `axum::middleware::Next`", + ); + }; + } + + quote! {} +} + +fn skip_next_arg(arg: &FnArg, kind: FunctionKind) -> bool { + match kind { + FunctionKind::Handler => true, + FunctionKind::Middleware => match arg { + FnArg::Receiver(_) => true, + FnArg::Typed(pat_type) => { + if let Type::Path(type_path) = &*pat_type.ty { + type_path + .path + .segments + .last() + .map_or(true, |path_segment| path_segment.ident != "Next") + } else { + true + } + } + }, + } +} + #[test] -fn ui() { +fn ui_debug_handler() { crate::run_ui_tests("debug_handler"); } + +#[test] +fn ui_debug_middleware() { + crate::run_ui_tests("debug_middleware"); +} diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index f35e7dae..7e3d465c 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -44,6 +44,7 @@ #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] +use debug_handler::FunctionKind; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse::Parse, Type}; @@ -464,7 +465,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { expand_with(item, |item| from_request::expand(item, FromRequestParts)) } -/// Generates better error messages when applied handler functions. +/// Generates better error messages when applied to handler functions. /// /// While using [`axum`], you can get long error messages for simple mistakes. For example: /// @@ -515,17 +516,15 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// /// As the error message says, handler function needs to be async. /// -/// ``` +/// ```no_run /// use axum::{routing::get, Router, debug_handler}; /// /// #[tokio::main] /// async fn main() { -/// # async { /// let app = Router::new().route("/", get(handler)); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, app).await.unwrap(); -/// # }; /// } /// /// #[debug_handler] @@ -618,7 +617,65 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { return input; #[cfg(debug_assertions)] - return expand_attr_with(_attr, input, debug_handler::expand); + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Handler) + }); +} + +/// Generates better error messages when applied to middleware functions. +/// +/// This works similarly to [`#[debug_handler]`](macro@debug_handler) except for middleware using +/// [`axum::middleware::from_fn`]. +/// +/// # Example +/// +/// ```no_run +/// use axum::{ +/// routing::get, +/// extract::Request, +/// response::Response, +/// Router, +/// middleware::{self, Next}, +/// debug_middleware, +/// }; +/// +/// #[tokio::main] +/// async fn main() { +/// let app = Router::new() +/// .route("/", get(|| async {})) +/// .layer(middleware::from_fn(my_middleware)); +/// +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); +/// axum::serve(listener, app).await.unwrap(); +/// } +/// +/// // if this wasn't a valid middleware function #[debug_middleware] would +/// // improve compile error +/// #[debug_middleware] +/// async fn my_middleware( +/// request: Request, +/// next: Next, +/// ) -> Response { +/// next.run(request).await +/// } +/// ``` +/// +/// # Performance +/// +/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`) +/// +/// [`axum`]: https://docs.rs/axum/latest +/// [`axum::middleware::from_fn`]: https://docs.rs/axum/0.7/axum/middleware/fn.from_fn.html +/// [`debug_middleware`]: macro@debug_middleware +#[proc_macro_attribute] +pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream { + #[cfg(not(debug_assertions))] + return input; + + #[cfg(debug_assertions)] + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Middleware) + }); } /// Private API: Do no use this! diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs new file mode 100644 index 00000000..12092e85 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs @@ -0,0 +1,13 @@ +use axum::{ + debug_middleware, + extract::Request, + response::{IntoResponse, Response}, +}; + +#[debug_middleware] +async fn my_middleware(request: Request) -> Response { + let _ = request; + ().into_response() +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr new file mode 100644 index 00000000..2474a4eb --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr @@ -0,0 +1,7 @@ +error: Middleware functions must take `axum::middleware::Next` as the last argument + --> tests/debug_middleware/fail/doesnt_take_next.rs:7:1 + | +7 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.rs b/axum-macros/tests/debug_middleware/fail/next_not_last.rs new file mode 100644 index 00000000..0108c854 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.rs @@ -0,0 +1,13 @@ +use axum::{ + extract::Request, + response::Response, + middleware::Next, + debug_middleware, +}; + +#[debug_middleware] +async fn my_middleware(next: Next, request: Request) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.stderr b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr new file mode 100644 index 00000000..8f08bed7 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr @@ -0,0 +1,5 @@ +error: `axum::middleware::Next` must the last argument + --> tests/debug_middleware/fail/next_not_last.rs:9:24 + | +9 | async fn my_middleware(next: Next, request: Request) -> Response { + | ^^^^^^^^^^ diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs new file mode 100644 index 00000000..995a97bd --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs @@ -0,0 +1,9 @@ +use axum::{debug_middleware, extract::Request, middleware::Next, response::Response}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next, next2: Next) -> Response { + let _ = next2; + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr new file mode 100644 index 00000000..596f5581 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr @@ -0,0 +1,7 @@ +error: Middleware functions can only take one argument of type `axum::middleware::Next` + --> tests/debug_middleware/fail/takes_next_twice.rs:3:1 + | +3 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/pass/basic.rs b/axum-macros/tests/debug_middleware/pass/basic.rs new file mode 100644 index 00000000..605cacfd --- /dev/null +++ b/axum-macros/tests/debug_middleware/pass/basic.rs @@ -0,0 +1,13 @@ +use axum::{ + extract::Request, + response::Response, + middleware::Next, + debug_middleware, +}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 316f2150..14b52f1f 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 a `Router` or `MethodRouter` ([#2586]) - **fixed:** `h2` is no longer pulled as a dependency unless the `http2` feature is enabled ([#2605]) +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 [#2586]: https://github.com/tokio-rs/axum/pull/2586 [#2605]: https://github.com/tokio-rs/axum/pull/2605 diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 601c14ae..fcd78a8b 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -463,7 +463,7 @@ pub use self::form::Form; pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; #[cfg(feature = "macros")] -pub use axum_macros::debug_handler; +pub use axum_macros::{debug_handler, debug_middleware}; #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[doc(inline)]