1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Support changing state type in #[debug_handler] ()

* support setting body type for #[debug_handler]

* Use lookahead1 to give better errors and detect duplicate arguments

* fix docs link
This commit is contained in:
David Pedersen 2022-08-18 11:41:14 +02:00 committed by GitHub
parent e7f1c88cd4
commit 568394a28e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 249 additions and 10 deletions

View file

@ -18,7 +18,11 @@ proc-macro = true
heck = "0.4"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["full"] }
syn = { version = "1.0", features = [
"full",
# needed for `Hash` impls
"extra-traits",
] }
[dev-dependencies]
axum = { path = "../axum", version = "0.5", features = ["headers"] }

View file

@ -1,8 +1,10 @@
use std::collections::HashSet;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
let check_extractor_count = check_extractor_count(&item_fn);
let check_request_last_extractor = check_request_last_extractor(&item_fn);
let check_path_extractor = check_path_extractor(&item_fn);
@ -12,8 +14,14 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
// If the function is generic, we can't reliably check its inputs or whether the future it
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
if attr.state_ty.is_none() {
attr.state_ty = state_type_from_args(&item_fn);
}
let state_ty = attr.state_ty.unwrap_or_else(|| syn::parse_quote!(()));
let check_inputs_impls_from_request =
check_inputs_impls_from_request(&item_fn, &attr.body_ty);
check_inputs_impls_from_request(&item_fn, &attr.body_ty, state_ty);
let check_future_send = check_future_send(&item_fn);
quote! {
@ -39,21 +47,46 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
}
}
mod kw {
syn::custom_keyword!(body);
syn::custom_keyword!(state);
}
pub(crate) struct Attrs {
body_ty: Type,
state_ty: Option<Type>,
}
impl Parse for Attrs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut body_ty = None;
let mut state_ty = None;
while !input.is_empty() {
let ident = input.parse::<syn::Ident>()?;
if ident == "body" {
let lh = input.lookahead1();
if lh.peek(kw::body) {
let kw = input.parse::<kw::body>()?;
if body_ty.is_some() {
return Err(syn::Error::new_spanned(
kw,
"`body` specified more than once",
));
}
input.parse::<Token![=]>()?;
body_ty = Some(input.parse()?);
} else if lh.peek(kw::state) {
let kw = input.parse::<kw::state>()?;
if state_ty.is_some() {
return Err(syn::Error::new_spanned(
kw,
"`state` specified more than once",
));
}
input.parse::<Token![=]>()?;
state_ty = Some(input.parse()?);
} else {
return Err(syn::Error::new_spanned(ident, "unknown argument"));
return Err(lh.error());
}
let _ = input.parse::<Token![,]>();
@ -61,7 +94,7 @@ impl Parse for Attrs {
let body_ty = body_ty.unwrap_or_else(|| syn::parse_quote!(axum::body::Body));
Ok(Self { body_ty })
Ok(Self { body_ty, state_ty })
}
}
@ -167,7 +200,11 @@ fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream {
}
}
fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStream {
fn check_inputs_impls_from_request(
item_fn: &ItemFn,
body_ty: &Type,
state_ty: Type,
) -> TokenStream {
item_fn
.sig
.inputs
@ -203,7 +240,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr
#[allow(warnings)]
fn #name()
where
#ty: ::axum::extract::FromRequest<(), #body_ty> + Send,
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send,
{}
}
})
@ -371,6 +408,68 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
None
}
/// Given a signature like
///
/// ```skip
/// #[debug_handler]
/// async fn handler(
/// _: axum::extract::State<AppState>,
/// _: State<AppState>,
/// ) {}
/// ```
///
/// This will extract `AppState`.
///
/// Returns `None` if there are no `State` args or multiple of different types.
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
let state_inputs = item_fn
.sig
.inputs
.iter()
.filter_map(|input| match input {
FnArg::Receiver(_) => None,
FnArg::Typed(pat_type) => Some(pat_type),
})
.map(|pat_type| &pat_type.ty)
.filter_map(|ty| {
if let Type::Path(path) = &**ty {
Some(&path.path)
} else {
None
}
})
.filter_map(|path| {
if let Some(last_segment) = path.segments.last() {
if last_segment.ident != "State" {
return None;
}
match &last_segment.arguments {
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
Some(args.args.first().unwrap())
}
_ => None,
}
} else {
None
}
})
.filter_map(|generic_arg| {
if let syn::GenericArgument::Type(ty) = generic_arg {
Some(ty)
} else {
None
}
})
.collect::<HashSet<_>>();
if state_inputs.len() == 1 {
state_inputs.iter().next().map(|&ty| ty.clone())
} else {
None
}
}
#[test]
fn ui() {
#[rustversion::stable]

View file

@ -513,12 +513,60 @@ pub fn derive_from_request(item: TokenStream) -> TokenStream {
/// async fn handler(request: Request<BoxBody>) {}
/// ```
///
/// # Changing state type
///
/// By default `#[debug_handler]` assumes your state type is `()` unless your handler has a
/// [`axum::extract::State`] argument:
///
/// ```
/// use axum::extract::State;
/// # use axum_macros::debug_handler;
///
/// #[debug_handler]
/// async fn handler(
/// // this makes `#[debug_handler]` use `AppState`
/// State(state): State<AppState>,
/// ) {}
///
/// #[derive(Clone)]
/// struct AppState {}
/// ```
///
/// If your handler takes multiple [`axum::extract::State`] arguments or you need to otherwise
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
///
/// ```
/// use axum::extract::{State, FromRef};
/// # use axum_macros::debug_handler;
///
/// #[debug_handler(state = AppState)]
/// async fn handler(
/// State(app_state): State<AppState>,
/// State(inner_state): State<InnerState>,
/// ) {}
///
/// #[derive(Clone)]
/// struct AppState {
/// inner: InnerState,
/// }
///
/// #[derive(Clone)]
/// struct InnerState {}
///
/// impl FromRef<AppState> for InnerState {
/// fn from_ref(state: &AppState) -> Self {
/// state.inner.clone()
/// }
/// }
/// ```
///
/// # Performance
///
/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`)
///
/// [`axum`]: https://docs.rs/axum/latest
/// [`Handler`]: https://docs.rs/axum/latest/axum/handler/trait.Handler.html
/// [`axum::extract::State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
/// [`debug_handler`]: macro@debug_handler
#[proc_macro_attribute]
pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {

View file

@ -0,0 +1,9 @@
use axum_macros::debug_handler;
#[debug_handler(body = BoxBody, body = BoxBody)]
async fn handler() {}
#[debug_handler(state = (), state = ())]
async fn handler_2() {}
fn main() {}

View file

@ -0,0 +1,11 @@
error: `body` specified more than once
--> tests/debug_handler/fail/duplicate_args.rs:3:33
|
3 | #[debug_handler(body = BoxBody, body = BoxBody)]
| ^^^^
error: `state` specified more than once
--> tests/debug_handler/fail/duplicate_args.rs:6:29
|
6 | #[debug_handler(state = (), state = ())]
| ^^^^^

View file

@ -1,4 +1,4 @@
error: unknown argument
error: expected `body` or `state`
--> tests/debug_handler/fail/invalid_attrs.rs:3:17
|
3 | #[debug_handler(foo)]

View file

@ -0,0 +1,31 @@
use axum_macros::debug_handler;
use axum::extract::State;
#[debug_handler]
async fn handler(_: State<AppState>) {}
#[debug_handler]
async fn handler_2(_: axum::extract::State<AppState>) {}
#[debug_handler]
async fn handler_3(
_: axum::extract::State<AppState>,
_: axum::extract::State<AppState>,
) {}
#[debug_handler]
async fn handler_4(
_: State<AppState>,
_: State<AppState>,
) {}
#[debug_handler]
async fn handler_5(
_: axum::extract::State<AppState>,
_: State<AppState>,
) {}
#[derive(Clone)]
struct AppState;
fn main() {}

View file

@ -0,0 +1,27 @@
use axum_macros::debug_handler;
use axum::extract::{FromRef, FromRequest, RequestParts};
use axum::async_trait;
#[debug_handler(state = AppState)]
async fn handler(_: A) {}
#[derive(Clone)]
struct AppState;
struct A;
#[async_trait]
impl<S, B> FromRequest<S, B> for A
where
B: Send,
S: Send + Sync,
AppState: FromRef<S>,
{
type Rejection = ();
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
unimplemented!()
}
}
fn main() {}

View file

@ -0,0 +1,10 @@
use axum_macros::debug_handler;
use axum::{body::BoxBody, extract::State, http::Request};
#[debug_handler(state = AppState, body = BoxBody)]
async fn handler(_: State<AppState>, _: Request<BoxBody>) {}
#[derive(Clone)]
struct AppState;
fn main() {}