mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-26 13:56:22 +02:00
Support changing state type in #[debug_handler]
(#1271)
* support setting body type for #[debug_handler] * Use lookahead1 to give better errors and detect duplicate arguments * fix docs link
This commit is contained in:
parent
e7f1c88cd4
commit
568394a28e
9 changed files with 249 additions and 10 deletions
axum-macros
|
@ -18,7 +18,11 @@ proc-macro = true
|
|||
heck = "0.4"
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
||||
syn = { version = "1.0", features = ["full"] }
|
||||
syn = { version = "1.0", features = [
|
||||
"full",
|
||||
# needed for `Hash` impls
|
||||
"extra-traits",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum = { path = "../axum", version = "0.5", features = ["headers"] }
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote, quote_spanned};
|
||||
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
||||
|
||||
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
||||
pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
||||
let check_extractor_count = check_extractor_count(&item_fn);
|
||||
let check_request_last_extractor = check_request_last_extractor(&item_fn);
|
||||
let check_path_extractor = check_path_extractor(&item_fn);
|
||||
|
@ -12,8 +14,14 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
|||
// If the function is generic, we can't reliably check its inputs or whether the future it
|
||||
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
|
||||
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
|
||||
if attr.state_ty.is_none() {
|
||||
attr.state_ty = state_type_from_args(&item_fn);
|
||||
}
|
||||
|
||||
let state_ty = attr.state_ty.unwrap_or_else(|| syn::parse_quote!(()));
|
||||
|
||||
let check_inputs_impls_from_request =
|
||||
check_inputs_impls_from_request(&item_fn, &attr.body_ty);
|
||||
check_inputs_impls_from_request(&item_fn, &attr.body_ty, state_ty);
|
||||
let check_future_send = check_future_send(&item_fn);
|
||||
|
||||
quote! {
|
||||
|
@ -39,21 +47,46 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
|||
}
|
||||
}
|
||||
|
||||
mod kw {
|
||||
syn::custom_keyword!(body);
|
||||
syn::custom_keyword!(state);
|
||||
}
|
||||
|
||||
pub(crate) struct Attrs {
|
||||
body_ty: Type,
|
||||
state_ty: Option<Type>,
|
||||
}
|
||||
|
||||
impl Parse for Attrs {
|
||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||
let mut body_ty = None;
|
||||
let mut state_ty = None;
|
||||
|
||||
while !input.is_empty() {
|
||||
let ident = input.parse::<syn::Ident>()?;
|
||||
if ident == "body" {
|
||||
let lh = input.lookahead1();
|
||||
|
||||
if lh.peek(kw::body) {
|
||||
let kw = input.parse::<kw::body>()?;
|
||||
if body_ty.is_some() {
|
||||
return Err(syn::Error::new_spanned(
|
||||
kw,
|
||||
"`body` specified more than once",
|
||||
));
|
||||
}
|
||||
input.parse::<Token![=]>()?;
|
||||
body_ty = Some(input.parse()?);
|
||||
} else if lh.peek(kw::state) {
|
||||
let kw = input.parse::<kw::state>()?;
|
||||
if state_ty.is_some() {
|
||||
return Err(syn::Error::new_spanned(
|
||||
kw,
|
||||
"`state` specified more than once",
|
||||
));
|
||||
}
|
||||
input.parse::<Token![=]>()?;
|
||||
state_ty = Some(input.parse()?);
|
||||
} else {
|
||||
return Err(syn::Error::new_spanned(ident, "unknown argument"));
|
||||
return Err(lh.error());
|
||||
}
|
||||
|
||||
let _ = input.parse::<Token![,]>();
|
||||
|
@ -61,7 +94,7 @@ impl Parse for Attrs {
|
|||
|
||||
let body_ty = body_ty.unwrap_or_else(|| syn::parse_quote!(axum::body::Body));
|
||||
|
||||
Ok(Self { body_ty })
|
||||
Ok(Self { body_ty, state_ty })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,7 +200,11 @@ fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStream {
|
||||
fn check_inputs_impls_from_request(
|
||||
item_fn: &ItemFn,
|
||||
body_ty: &Type,
|
||||
state_ty: Type,
|
||||
) -> TokenStream {
|
||||
item_fn
|
||||
.sig
|
||||
.inputs
|
||||
|
@ -203,7 +240,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr
|
|||
#[allow(warnings)]
|
||||
fn #name()
|
||||
where
|
||||
#ty: ::axum::extract::FromRequest<(), #body_ty> + Send,
|
||||
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send,
|
||||
{}
|
||||
}
|
||||
})
|
||||
|
@ -371,6 +408,68 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
|
|||
None
|
||||
}
|
||||
|
||||
/// Given a signature like
|
||||
///
|
||||
/// ```skip
|
||||
/// #[debug_handler]
|
||||
/// async fn handler(
|
||||
/// _: axum::extract::State<AppState>,
|
||||
/// _: State<AppState>,
|
||||
/// ) {}
|
||||
/// ```
|
||||
///
|
||||
/// This will extract `AppState`.
|
||||
///
|
||||
/// Returns `None` if there are no `State` args or multiple of different types.
|
||||
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
||||
let state_inputs = item_fn
|
||||
.sig
|
||||
.inputs
|
||||
.iter()
|
||||
.filter_map(|input| match input {
|
||||
FnArg::Receiver(_) => None,
|
||||
FnArg::Typed(pat_type) => Some(pat_type),
|
||||
})
|
||||
.map(|pat_type| &pat_type.ty)
|
||||
.filter_map(|ty| {
|
||||
if let Type::Path(path) = &**ty {
|
||||
Some(&path.path)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.filter_map(|path| {
|
||||
if let Some(last_segment) = path.segments.last() {
|
||||
if last_segment.ident != "State" {
|
||||
return None;
|
||||
}
|
||||
|
||||
match &last_segment.arguments {
|
||||
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
|
||||
Some(args.args.first().unwrap())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.filter_map(|generic_arg| {
|
||||
if let syn::GenericArgument::Type(ty) = generic_arg {
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
if state_inputs.len() == 1 {
|
||||
state_inputs.iter().next().map(|&ty| ty.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ui() {
|
||||
#[rustversion::stable]
|
||||
|
|
|
@ -513,12 +513,60 @@ pub fn derive_from_request(item: TokenStream) -> TokenStream {
|
|||
/// async fn handler(request: Request<BoxBody>) {}
|
||||
/// ```
|
||||
///
|
||||
/// # Changing state type
|
||||
///
|
||||
/// By default `#[debug_handler]` assumes your state type is `()` unless your handler has a
|
||||
/// [`axum::extract::State`] argument:
|
||||
///
|
||||
/// ```
|
||||
/// use axum::extract::State;
|
||||
/// # use axum_macros::debug_handler;
|
||||
///
|
||||
/// #[debug_handler]
|
||||
/// async fn handler(
|
||||
/// // this makes `#[debug_handler]` use `AppState`
|
||||
/// State(state): State<AppState>,
|
||||
/// ) {}
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct AppState {}
|
||||
/// ```
|
||||
///
|
||||
/// If your handler takes multiple [`axum::extract::State`] arguments or you need to otherwise
|
||||
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
|
||||
///
|
||||
/// ```
|
||||
/// use axum::extract::{State, FromRef};
|
||||
/// # use axum_macros::debug_handler;
|
||||
///
|
||||
/// #[debug_handler(state = AppState)]
|
||||
/// async fn handler(
|
||||
/// State(app_state): State<AppState>,
|
||||
/// State(inner_state): State<InnerState>,
|
||||
/// ) {}
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct AppState {
|
||||
/// inner: InnerState,
|
||||
/// }
|
||||
///
|
||||
/// #[derive(Clone)]
|
||||
/// struct InnerState {}
|
||||
///
|
||||
/// impl FromRef<AppState> for InnerState {
|
||||
/// fn from_ref(state: &AppState) -> Self {
|
||||
/// state.inner.clone()
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # Performance
|
||||
///
|
||||
/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`)
|
||||
///
|
||||
/// [`axum`]: https://docs.rs/axum/latest
|
||||
/// [`Handler`]: https://docs.rs/axum/latest/axum/handler/trait.Handler.html
|
||||
/// [`axum::extract::State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
|
||||
/// [`debug_handler`]: macro@debug_handler
|
||||
#[proc_macro_attribute]
|
||||
pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
|
|
9
axum-macros/tests/debug_handler/fail/duplicate_args.rs
Normal file
9
axum-macros/tests/debug_handler/fail/duplicate_args.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
use axum_macros::debug_handler;
|
||||
|
||||
#[debug_handler(body = BoxBody, body = BoxBody)]
|
||||
async fn handler() {}
|
||||
|
||||
#[debug_handler(state = (), state = ())]
|
||||
async fn handler_2() {}
|
||||
|
||||
fn main() {}
|
11
axum-macros/tests/debug_handler/fail/duplicate_args.stderr
Normal file
11
axum-macros/tests/debug_handler/fail/duplicate_args.stderr
Normal file
|
@ -0,0 +1,11 @@
|
|||
error: `body` specified more than once
|
||||
--> tests/debug_handler/fail/duplicate_args.rs:3:33
|
||||
|
|
||||
3 | #[debug_handler(body = BoxBody, body = BoxBody)]
|
||||
| ^^^^
|
||||
|
||||
error: `state` specified more than once
|
||||
--> tests/debug_handler/fail/duplicate_args.rs:6:29
|
||||
|
|
||||
6 | #[debug_handler(state = (), state = ())]
|
||||
| ^^^^^
|
|
@ -1,4 +1,4 @@
|
|||
error: unknown argument
|
||||
error: expected `body` or `state`
|
||||
--> tests/debug_handler/fail/invalid_attrs.rs:3:17
|
||||
|
|
||||
3 | #[debug_handler(foo)]
|
||||
|
|
31
axum-macros/tests/debug_handler/pass/infer_state.rs
Normal file
31
axum-macros/tests/debug_handler/pass/infer_state.rs
Normal file
|
@ -0,0 +1,31 @@
|
|||
use axum_macros::debug_handler;
|
||||
use axum::extract::State;
|
||||
|
||||
#[debug_handler]
|
||||
async fn handler(_: State<AppState>) {}
|
||||
|
||||
#[debug_handler]
|
||||
async fn handler_2(_: axum::extract::State<AppState>) {}
|
||||
|
||||
#[debug_handler]
|
||||
async fn handler_3(
|
||||
_: axum::extract::State<AppState>,
|
||||
_: axum::extract::State<AppState>,
|
||||
) {}
|
||||
|
||||
#[debug_handler]
|
||||
async fn handler_4(
|
||||
_: State<AppState>,
|
||||
_: State<AppState>,
|
||||
) {}
|
||||
|
||||
#[debug_handler]
|
||||
async fn handler_5(
|
||||
_: axum::extract::State<AppState>,
|
||||
_: State<AppState>,
|
||||
) {}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState;
|
||||
|
||||
fn main() {}
|
27
axum-macros/tests/debug_handler/pass/set_state.rs
Normal file
27
axum-macros/tests/debug_handler/pass/set_state.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
use axum_macros::debug_handler;
|
||||
use axum::extract::{FromRef, FromRequest, RequestParts};
|
||||
use axum::async_trait;
|
||||
|
||||
#[debug_handler(state = AppState)]
|
||||
async fn handler(_: A) {}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState;
|
||||
|
||||
struct A;
|
||||
|
||||
#[async_trait]
|
||||
impl<S, B> FromRequest<S, B> for A
|
||||
where
|
||||
B: Send,
|
||||
S: Send + Sync,
|
||||
AppState: FromRef<S>,
|
||||
{
|
||||
type Rejection = ();
|
||||
|
||||
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
10
axum-macros/tests/debug_handler/pass/state_and_body.rs
Normal file
10
axum-macros/tests/debug_handler/pass/state_and_body.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
use axum_macros::debug_handler;
|
||||
use axum::{body::BoxBody, extract::State, http::Request};
|
||||
|
||||
#[debug_handler(state = AppState, body = BoxBody)]
|
||||
async fn handler(_: State<AppState>, _: Request<BoxBody>) {}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState;
|
||||
|
||||
fn main() {}
|
Loading…
Add table
Reference in a new issue