diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index a490702e..0162e08f 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -26,7 +26,7 @@ syn = { version = "1.0", features = [ [dev-dependencies] axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] } -axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing"] } +axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 4ed0a5a2..e3b4a885 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -4,7 +4,6 @@ use crate::{ }; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; -use std::collections::HashSet; use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { @@ -435,7 +434,7 @@ fn self_receiver(item_fn: &ItemFn) -> Option { /// /// Returns `None` if there are no `State` args or multiple of different types. fn state_type_from_args(item_fn: &ItemFn) -> Option { - let state_inputs = item_fn + let types = item_fn .sig .inputs .iter() @@ -443,44 +442,8 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option { FnArg::Receiver(_) => None, FnArg::Typed(pat_type) => Some(pat_type), }) - .map(|pat_type| &pat_type.ty) - .filter_map(|ty| { - if let Type::Path(path) = &**ty { - Some(&path.path) - } else { - None - } - }) - .filter_map(|path| { - if let Some(last_segment) = path.segments.last() { - if last_segment.ident != "State" { - return None; - } - - match &last_segment.arguments { - syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => { - Some(args.args.first().unwrap()) - } - _ => None, - } - } else { - None - } - }) - .filter_map(|generic_arg| { - if let syn::GenericArgument::Type(ty) = generic_arg { - Some(ty) - } else { - None - } - }) - .collect::>(); - - if state_inputs.len() == 1 { - state_inputs.iter().next().map(|&ty| ty.clone()) - } else { - None - } + .map(|pat_type| &*pat_type.ty); + crate::infer_state_type(types) } #[test] diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 2cfff53c..6803a2c2 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -4,9 +4,11 @@ use crate::{ from_request::attr::FromRequestFieldAttrs, }; use proc_macro2::{Span, TokenStream}; -use quote::{quote, quote_spanned}; -use std::fmt; -use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token}; +use quote::{quote, quote_spanned, ToTokens}; +use std::{collections::HashSet, fmt, iter}; +use syn::{ + parse_quote, punctuated::Punctuated, spanned::Spanned, Fields, Ident, Path, Token, Type, +}; mod attr; @@ -16,6 +18,22 @@ pub(crate) enum Trait { FromRequestParts, } +impl Trait { + fn body_type(&self) -> impl Iterator { + match self { + Trait::FromRequest => Some(parse_quote!(B)).into_iter(), + Trait::FromRequestParts => None.into_iter(), + } + } + + fn via_marker_type(&self) -> Option { + match self { + Trait::FromRequest => Some(parse_quote!(M)), + Trait::FromRequestParts => None, + } + } +} + impl fmt::Display for Trait { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -25,6 +43,55 @@ impl fmt::Display for Trait { } } +#[derive(Debug)] +enum State { + Custom(syn::Type), + Default(syn::Type), +} + +impl State { + /// ```not_rust + /// impl A for B {} + /// ^ this type + /// ``` + fn impl_generics(&self) -> impl Iterator { + match self { + State::Default(inner) => Some(inner.clone()), + State::Custom(_) => None, + } + .into_iter() + } + + /// ```not_rust + /// impl A for B {} + /// ^ this type + /// ``` + fn trait_generics(&self) -> impl Iterator { + match self { + State::Default(inner) => iter::once(inner.clone()), + State::Custom(inner) => iter::once(inner.clone()), + } + } + + fn bounds(&self) -> TokenStream { + match self { + State::Custom(_) => quote! {}, + State::Default(inner) => quote! { + #inner: ::std::marker::Send + ::std::marker::Sync, + }, + } + } +} + +impl ToTokens for State { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + State::Custom(inner) => inner.to_tokens(tokens), + State::Default(inner) => inner.to_tokens(tokens), + } + } +} + pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { match item { syn::Item::Struct(item) => { @@ -40,7 +107,23 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?; - let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?; + let FromRequestContainerAttrs { + via, + rejection, + state, + } = parse_attrs("from_request", &attrs)?; + + let state = match state { + Some((_, state)) => State::Custom(state), + None => infer_state_type_from_field_types(&fields) + .map(State::Custom) + .or_else(|| infer_state_type_from_field_attributes(&fields).map(State::Custom)) + .or_else(|| { + let via = via.as_ref().map(|(_, via)| via)?; + state_from_via(&ident, via).map(State::Custom) + }) + .unwrap_or_else(|| State::Default(syn::parse_quote!(S))), + }; match (via.map(second), rejection.map(second)) { (Some(via), rejection) => impl_struct_by_extracting_all_at_once( @@ -49,11 +132,12 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { via, rejection, generic_ident, + state, tr, ), (None, rejection) => { error_on_generic_ident(generic_ident, tr)?; - impl_struct_by_extracting_each_field(ident, fields, rejection, tr) + impl_struct_by_extracting_each_field(ident, fields, rejection, state, tr) } } } @@ -78,7 +162,20 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { return Err(syn::Error::new_spanned(where_clause, generics_error)); } - let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?; + let FromRequestContainerAttrs { + via, + rejection, + state, + } = parse_attrs("from_request", &attrs)?; + + let state = match state { + Some((_, state)) => State::Custom(state), + None => (|| { + let via = via.as_ref().map(|(_, via)| via)?; + state_from_via(&ident, via).map(State::Custom) + })() + .unwrap_or_else(|| State::Default(syn::parse_quote!(S))), + }; match (via.map(second), rejection) { (Some(via), rejection) => impl_enum_by_extracting_all_at_once( @@ -86,6 +183,7 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { variants, via, rejection.map(second), + state, tr, ), (None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned( @@ -210,6 +308,7 @@ fn impl_struct_by_extracting_each_field( ident: syn::Ident, fields: syn::Fields, rejection: Option, + state: State, tr: Trait, ) -> syn::Result { let extract_fields = extract_fields(&fields, &rejection, tr)?; @@ -222,22 +321,34 @@ fn impl_struct_by_extracting_each_field( quote!(::axum::response::Response) }; + let impl_generics = tr + .body_type() + .chain(state.impl_generics()) + .collect::>(); + + let trait_generics = state + .trait_generics() + .chain(tr.body_type()) + .collect::>(); + + let state_bounds = state.bounds(); + Ok(match tr { Trait::FromRequest => quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #rejection_ident; async fn from_request( mut req: ::axum::http::Request, - state: &S, + state: &#state, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -248,15 +359,15 @@ fn impl_struct_by_extracting_each_field( Trait::FromRequestParts => quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequestParts for #ident + impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #rejection_ident; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - state: &S, + state: &#state, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -547,9 +658,10 @@ fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> { fn impl_struct_by_extracting_all_at_once( ident: syn::Ident, fields: syn::Fields, - path: syn::Path, + via_path: syn::Path, rejection: Option, generic_ident: Option, + state: State, tr: Trait, ) -> syn::Result { let fields = match fields { @@ -570,7 +682,7 @@ fn impl_struct_by_extracting_all_at_once( } } - let path_span = path.span(); + let path_span = via_path.span(); let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection { let rejection = quote! { #rejection }; @@ -584,43 +696,68 @@ fn impl_struct_by_extracting_all_at_once( (rejection, map_err) }; + // for something like + // + // ``` + // #[derive(Clone, Default, FromRequest)] + // #[from_request(via(State))] + // struct AppState {} + // ``` + // + // we need to implement `impl FromRequest` but only for + // - `#[derive(FromRequest)]`, not `#[derive(FromRequestParts)]` + // - `State`, not other extractors + // + // honestly not sure why but the tests all pass + let via_marker_type = if path_ident_is_state(&via_path) { + tr.via_marker_type() + } else { + None + }; + + let impl_generics = tr + .body_type() + .chain(via_marker_type.clone()) + .chain(state.impl_generics()) + .chain(generic_ident.is_some().then(|| parse_quote!(T))) + .collect::>(); + + let trait_generics = state + .trait_generics() + .chain(tr.body_type()) + .chain(via_marker_type) + .collect::>(); + + let ident_generics = generic_ident + .is_some() + .then(|| quote! { }) + .unwrap_or_default(); + let rejection_bound = rejection.as_ref().map(|rejection| { match (tr, generic_ident.is_some()) { (Trait::FromRequest, true) => { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequest<#trait_generics>>::Rejection>, } }, (Trait::FromRequest, false) => { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequest<#trait_generics>>::Rejection>, } }, (Trait::FromRequestParts, true) => { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequestParts>::Rejection>, + #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>, } }, (Trait::FromRequestParts, false) => { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequestParts>::Rejection>, + #rejection: ::std::convert::From<<#via_path as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>, } } } }).unwrap_or_default(); - let impl_generics = match (tr, generic_ident.is_some()) { - (Trait::FromRequest, true) => quote! { S, B, T }, - (Trait::FromRequest, false) => quote! { S, B }, - (Trait::FromRequestParts, true) => quote! { S, T }, - (Trait::FromRequestParts, false) => quote! { S }, - }; - - let type_generics = generic_ident - .is_some() - .then(|| quote! { }) - .unwrap_or_default(); - let via_type_generics = if generic_ident.is_some() { quote! { T } } else { @@ -635,27 +772,29 @@ fn impl_struct_by_extracting_all_at_once( quote! { value } }; + let state_bounds = state.bounds(); + let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl<#impl_generics> ::axum::extract::FromRequest for #ident #type_generics + impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics where - #path<#via_type_generics>: ::axum::extract::FromRequest, + #via_path<#via_type_generics>: ::axum::extract::FromRequest<#trait_generics>, #rejection_bound B: ::std::marker::Send + 'static, - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #associated_rejection_type; async fn from_request( req: ::axum::http::Request, - state: &S + state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequest::from_request(req, state) .await - .map(|#path(value)| #value_to_self) + .map(|#via_path(value)| #value_to_self) .map_err(#map_err) } } @@ -665,21 +804,21 @@ fn impl_struct_by_extracting_all_at_once( quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl<#impl_generics> ::axum::extract::FromRequestParts for #ident #type_generics + impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics where - #path<#via_type_generics>: ::axum::extract::FromRequestParts, + #via_path<#via_type_generics>: ::axum::extract::FromRequestParts<#trait_generics>, #rejection_bound - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #associated_rejection_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - state: &S + state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequestParts::from_request_parts(parts, state) .await - .map(|#path(value)| #value_to_self) + .map(|#via_path(value)| #value_to_self) .map_err(#map_err) } } @@ -695,6 +834,7 @@ fn impl_enum_by_extracting_all_at_once( variants: Punctuated, path: syn::Path, rejection: Option, + state: State, tr: Trait, ) -> syn::Result { for variant in variants { @@ -738,23 +878,35 @@ fn impl_enum_by_extracting_all_at_once( let path_span = path.span(); + let impl_generics = tr + .body_type() + .chain(state.impl_generics()) + .collect::>(); + + let trait_generics = state + .trait_generics() + .chain(tr.body_type()) + .collect::>(); + + let state_bounds = state.bounds(); + let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #associated_rejection_type; async fn from_request( req: ::axum::http::Request, - state: &S + state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequest::from_request(req, state) .await @@ -768,15 +920,15 @@ fn impl_enum_by_extracting_all_at_once( quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequestParts for #ident + impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where - S: ::std::marker::Send + ::std::marker::Sync, + #state_bounds { type Rejection = #associated_rejection_type; async fn from_request_parts( parts: &mut ::axum::http::request::Parts, - state: &S + state: &#state, ) -> ::std::result::Result { ::axum::extract::FromRequestParts::from_request_parts(parts, state) .await @@ -791,6 +943,90 @@ fn impl_enum_by_extracting_all_at_once( Ok(tokens) } +/// For a struct like +/// +/// ```skip +/// struct Extractor { +/// state: State, +/// } +/// ``` +/// +/// We can infer the state type to be `AppState` because it appears inside a `State` +fn infer_state_type_from_field_types(fields: &Fields) -> Option { + match fields { + Fields::Named(fields_named) => { + crate::infer_state_type(fields_named.named.iter().map(|field| &field.ty)) + } + Fields::Unnamed(fields_unnamed) => { + crate::infer_state_type(fields_unnamed.unnamed.iter().map(|field| &field.ty)) + } + Fields::Unit => None, + } +} + +/// For a struct like +/// +/// ```skip +/// struct Extractor { +/// #[from_request(via(State))] +/// state: AppState, +/// } +/// ``` +/// +/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be +/// extracted with `State` +fn infer_state_type_from_field_attributes(fields: &Fields) -> Option { + let state_inputs = match fields { + Fields::Named(fields_named) => { + fields_named + .named + .iter() + .filter_map(|field| { + // TODO(david): its a little wasteful to parse the attributes again here + // ideally we should parse things once and pass the data down + let FromRequestFieldAttrs { via } = + parse_attrs("from_request", &field.attrs).ok()?; + let (_, via_path) = via?; + path_ident_is_state(&via_path).then(|| &field.ty) + }) + .collect::>() + } + Fields::Unnamed(fields_unnamed) => { + fields_unnamed + .unnamed + .iter() + .filter_map(|field| { + // TODO(david): its a little wasteful to parse the attributes again here + // ideally we should parse things once and pass the data down + let FromRequestFieldAttrs { via } = + parse_attrs("from_request", &field.attrs).ok()?; + let (_, via_path) = via?; + path_ident_is_state(&via_path).then(|| &field.ty) + }) + .collect::>() + } + Fields::Unit => return None, + }; + + if state_inputs.len() == 1 { + state_inputs.iter().next().map(|&ty| ty.clone()) + } else { + None + } +} + +fn path_ident_is_state(path: &Path) -> bool { + if let Some(last_segment) = path.segments.last() { + last_segment.ident == "State" + } else { + false + } +} + +fn state_from_via(ident: &Ident, via: &Path) -> Option { + path_ident_is_state(via).then(|| parse_quote!(#ident)) +} + #[test] fn ui() { crate::run_ui_tests("from_request"); diff --git a/axum-macros/src/from_request/attr.rs b/axum-macros/src/from_request/attr.rs index 45ef1f24..77dfc470 100644 --- a/axum-macros/src/from_request/attr.rs +++ b/axum-macros/src/from_request/attr.rs @@ -7,18 +7,21 @@ use syn::{ pub(crate) mod kw { syn::custom_keyword!(via); syn::custom_keyword!(rejection); + syn::custom_keyword!(state); } #[derive(Default)] pub(super) struct FromRequestContainerAttrs { pub(super) via: Option<(kw::via, syn::Path)>, pub(super) rejection: Option<(kw::rejection, syn::Path)>, + pub(super) state: Option<(kw::state, syn::Type)>, } impl Parse for FromRequestContainerAttrs { fn parse(input: ParseStream) -> syn::Result { let mut via = None; let mut rejection = None; + let mut state = None; while !input.is_empty() { let lh = input.lookahead1(); @@ -26,6 +29,8 @@ impl Parse for FromRequestContainerAttrs { parse_parenthesized_attribute(input, &mut via)?; } else if lh.peek(kw::rejection) { parse_parenthesized_attribute(input, &mut rejection)?; + } else if lh.peek(kw::state) { + parse_parenthesized_attribute(input, &mut state)?; } else { return Err(lh.error()); } @@ -33,15 +38,24 @@ impl Parse for FromRequestContainerAttrs { let _ = input.parse::(); } - Ok(Self { via, rejection }) + Ok(Self { + via, + rejection, + state, + }) } } impl Combine for FromRequestContainerAttrs { fn combine(mut self, other: Self) -> syn::Result { - let Self { via, rejection } = other; + let Self { + via, + rejection, + state, + } = other; combine_attribute(&mut self.via, via)?; combine_attribute(&mut self.rejection, rejection)?; + combine_attribute(&mut self.state, state)?; Ok(self) } } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 690127b5..0529315c 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -43,9 +43,11 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] +use std::collections::HashSet; + use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::parse::Parse; +use syn::{parse::Parse, Type}; mod attr_parsing; mod debug_handler; @@ -613,6 +615,50 @@ where } } +fn infer_state_type<'a, I>(types: I) -> Option +where + I: Iterator, +{ + let state_inputs = types + .filter_map(|ty| { + if let Type::Path(path) = ty { + Some(&path.path) + } else { + None + } + }) + .filter_map(|path| { + if let Some(last_segment) = path.segments.last() { + if last_segment.ident != "State" { + return None; + } + + match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => { + Some(args.args.first().unwrap()) + } + _ => None, + } + } else { + None + } + }) + .filter_map(|generic_arg| { + if let syn::GenericArgument::Type(ty) = generic_arg { + Some(ty) + } else { + None + } + }) + .collect::>(); + + if state_inputs.len() == 1 { + state_inputs.iter().next().map(|&ty| ty.clone()) + } else { + None + } +} + #[cfg(test)] fn run_ui_tests(directory: &str) { #[rustversion::stable] diff --git a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs new file mode 100644 index 00000000..57400377 --- /dev/null +++ b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs @@ -0,0 +1,22 @@ +use axum_macros::FromRequest; +use axum::extract::State; + +#[derive(FromRequest)] +struct Extractor { + inner_state: State, + other_state: State, +} + +#[derive(Clone)] +struct AppState {} + +#[derive(Clone)] +struct OtherState {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequest, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr new file mode 100644 index 00000000..7b9bfdb3 --- /dev/null +++ b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.stderr @@ -0,0 +1,23 @@ +error[E0277]: the trait bound `AppState: FromRef` is not satisfied + --> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18 + | +6 | inner_state: State, + | ^^^^^ the trait `FromRef` is not implemented for `AppState` + | + = note: required because of the requirements on the impl of `FromRequestParts` for `State` +help: consider extending the `where` clause, but there might be an alternative better way to express this requirement + | +4 | #[derive(FromRequest, AppState: FromRef)] + | ++++++++++++++++++++++ + +error[E0277]: the trait bound `OtherState: FromRef` is not satisfied + --> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18 + | +7 | other_state: State, + | ^^^^^ the trait `FromRef` is not implemented for `OtherState` + | + = note: required because of the requirements on the impl of `FromRequestParts` for `State` +help: consider extending the `where` clause, but there might be an alternative better way to express this requirement + | +4 | #[derive(FromRequest, OtherState: FromRef)] + | ++++++++++++++++++++++++ diff --git a/axum-macros/tests/from_request/fail/unknown_attr_container.stderr b/axum-macros/tests/from_request/fail/unknown_attr_container.stderr index 25eeca56..09a25165 100644 --- a/axum-macros/tests/from_request/fail/unknown_attr_container.stderr +++ b/axum-macros/tests/from_request/fail/unknown_attr_container.stderr @@ -1,4 +1,4 @@ -error: expected `via` or `rejection` +error: expected one of: `via`, `rejection`, `state` --> tests/from_request/fail/unknown_attr_container.rs:4:16 | 4 | #[from_request(foo)] diff --git a/axum-macros/tests/from_request/pass/state_cookie.rs b/axum-macros/tests/from_request/pass/state_cookie.rs new file mode 100644 index 00000000..a4f46c6a --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_cookie.rs @@ -0,0 +1,27 @@ +use axum_macros::FromRequest; +use axum::extract::FromRef; +use axum_extra::extract::cookie::{PrivateCookieJar, Key}; + +#[derive(FromRequest)] +#[from_request(state(AppState))] +struct Extractor { + cookies: PrivateCookieJar, +} + +struct AppState { + key: Key, +} + +impl FromRef for Key { + fn from_ref(input: &AppState) -> Self { + input.key.clone() + } +} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequest, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/state_enum_via.rs b/axum-macros/tests/from_request/pass/state_enum_via.rs new file mode 100644 index 00000000..99af401c --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_enum_via.rs @@ -0,0 +1,34 @@ +use axum::{ + extract::{State, FromRef}, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/a", get(|_: AppState| async {})) + .route("/b", get(|_: InnerState| async {})); +} + +#[derive(Clone, FromRequest)] +#[from_request(via(State))] +enum AppState { + One, +} + +impl Default for AppState { + fn default() -> AppState { + Self::One + } +} + +#[derive(FromRequest)] +#[from_request(via(State), state(AppState))] +enum InnerState {} + +impl FromRef for InnerState { + fn from_ref(_: &AppState) -> Self { + todo!(":shrug:") + } +} diff --git a/axum-macros/tests/from_request/pass/state_enum_via_parts.rs b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs new file mode 100644 index 00000000..9700ac94 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_enum_via_parts.rs @@ -0,0 +1,35 @@ +use axum::{ + extract::{State, FromRef}, + routing::get, + Router, +}; +use axum_macros::FromRequestParts; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/a", get(|_: AppState| async {})) + .route("/b", get(|_: InnerState| async {})) + .route("/c", get(|_: AppState, _: InnerState| async {})); +} + +#[derive(Clone, FromRequestParts)] +#[from_request(via(State))] +enum AppState { + One, +} + +impl Default for AppState { + fn default() -> AppState { + Self::One + } +} + +#[derive(FromRequestParts)] +#[from_request(via(State), state(AppState))] +enum InnerState {} + +impl FromRef for InnerState { + fn from_ref(_: &AppState) -> Self { + todo!(":shrug:") + } +} diff --git a/axum-macros/tests/from_request/pass/state_explicit.rs b/axum-macros/tests/from_request/pass/state_explicit.rs new file mode 100644 index 00000000..5a608eab --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_explicit.rs @@ -0,0 +1,44 @@ +use axum_macros::FromRequest; +use axum::{ + extract::{FromRef, State}, + Router, + routing::get, +}; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/b", get(|_: Extractor| async {})); +} + +#[derive(FromRequest)] +#[from_request(state(AppState))] +struct Extractor { + app_state: State, + one: State, + two: State, + other_extractor: String, +} + +#[derive(Clone, Default)] +struct AppState { + one: One, + two: Two, +} + +#[derive(Clone, Default)] +struct One {} + +impl FromRef for One { + fn from_ref(input: &AppState) -> Self { + input.one.clone() + } +} + +#[derive(Clone, Default)] +struct Two {} + +impl FromRef for Two { + fn from_ref(input: &AppState) -> Self { + input.two.clone() + } +} diff --git a/axum-macros/tests/from_request/pass/state_explicit_parts.rs b/axum-macros/tests/from_request/pass/state_explicit_parts.rs new file mode 100644 index 00000000..5581ef5f --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_explicit_parts.rs @@ -0,0 +1,33 @@ +use axum_macros::FromRequestParts; +use axum::{ + extract::{FromRef, State, Query}, + Router, + routing::get, +}; +use std::collections::HashMap; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/b", get(|_: Extractor| async {})); +} + +#[derive(FromRequestParts)] +#[from_request(state(AppState))] +struct Extractor { + inner_state: State, + other: Query>, +} + +#[derive(Default)] +struct AppState { + inner: InnerState, +} + +#[derive(Clone, Default)] +struct InnerState {} + +impl FromRef for InnerState { + fn from_ref(input: &AppState) -> Self { + input.inner.clone() + } +} diff --git a/axum-macros/tests/from_request/pass/state_field_explicit.rs b/axum-macros/tests/from_request/pass/state_field_explicit.rs new file mode 100644 index 00000000..363efab8 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_field_explicit.rs @@ -0,0 +1,34 @@ +use axum::{ + extract::{State, FromRef}, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/", get(|_: Extractor| async {})); +} + +#[derive(FromRequest)] +#[from_request(state(AppState))] +struct Extractor { + #[from_request(via(State))] + state: AppState, + #[from_request(via(State))] + inner: InnerState, +} + +#[derive(Clone, Default)] +struct AppState { + inner: InnerState, +} + +#[derive(Clone, Default)] +struct InnerState {} + +impl FromRef for InnerState { + fn from_ref(input: &AppState) -> Self { + input.inner.clone() + } +} diff --git a/axum-macros/tests/from_request/pass/state_field_infer.rs b/axum-macros/tests/from_request/pass/state_field_infer.rs new file mode 100644 index 00000000..03330578 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_field_infer.rs @@ -0,0 +1,20 @@ +use axum::{ + extract::State, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/", get(|_: Extractor| async {})); +} + +#[derive(FromRequest)] +struct Extractor { + #[from_request(via(State))] + state: AppState, +} + +#[derive(Clone, Default)] +struct AppState {} diff --git a/axum-macros/tests/from_request/pass/state_infer.rs b/axum-macros/tests/from_request/pass/state_infer.rs new file mode 100644 index 00000000..52906149 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_infer.rs @@ -0,0 +1,18 @@ +use axum_macros::FromRequest; +use axum::extract::State; + +#[derive(FromRequest)] +struct Extractor { + inner_state: State, +} + +#[derive(Clone)] +struct AppState {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequest, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/state_infer_multiple.rs b/axum-macros/tests/from_request/pass/state_infer_multiple.rs new file mode 100644 index 00000000..6729e615 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_infer_multiple.rs @@ -0,0 +1,19 @@ +use axum_macros::FromRequest; +use axum::extract::State; + +#[derive(FromRequest)] +struct Extractor { + inner_state: State, + also_inner_state: State, +} + +#[derive(Clone)] +struct AppState {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequest, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/state_infer_parts.rs b/axum-macros/tests/from_request/pass/state_infer_parts.rs new file mode 100644 index 00000000..f3f078c5 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_infer_parts.rs @@ -0,0 +1,18 @@ +use axum_macros::FromRequestParts; +use axum::extract::State; + +#[derive(FromRequestParts)] +struct Extractor { + inner_state: State, +} + +#[derive(Clone)] +struct AppState {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/state_via.rs b/axum-macros/tests/from_request/pass/state_via.rs new file mode 100644 index 00000000..b7196a39 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_via.rs @@ -0,0 +1,28 @@ +use axum::{ + extract::{FromRef, State}, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/b", get(|_: (), _: AppState| async {})) + .route("/c", get(|_: (), _: InnerState| async {})); +} + +#[derive(Clone, Default, FromRequest)] +#[from_request(via(State), state(AppState))] +struct AppState { + inner: InnerState, +} + +#[derive(Clone, Default, FromRequest)] +#[from_request(via(State), state(AppState))] +struct InnerState {} + +impl FromRef for InnerState { + fn from_ref(input: &AppState) -> Self { + input.inner.clone() + } +} diff --git a/axum-macros/tests/from_request/pass/state_via_infer.rs b/axum-macros/tests/from_request/pass/state_via_infer.rs new file mode 100644 index 00000000..75b170a6 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_via_infer.rs @@ -0,0 +1,17 @@ +use axum::{ + extract::State, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/b", get(|_: AppState| async {})); +} + +// if we're extract "via" `State` and not specifying state +// assume `AppState` is the state +#[derive(Clone, Default, FromRequest)] +#[from_request(via(State))] +struct AppState {} diff --git a/axum-macros/tests/from_request/pass/state_via_parts.rs b/axum-macros/tests/from_request/pass/state_via_parts.rs new file mode 100644 index 00000000..b747f474 --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_via_parts.rs @@ -0,0 +1,29 @@ +use axum::{ + extract::{FromRef, State}, + routing::get, + Router, +}; +use axum_macros::FromRequestParts; + +fn main() { + let _: Router = Router::with_state(AppState::default()) + .route("/a", get(|_: AppState, _: InnerState, _: String| async {})) + .route("/b", get(|_: AppState, _: String| async {})) + .route("/c", get(|_: InnerState, _: String| async {})); +} + +#[derive(Clone, Default, FromRequestParts)] +#[from_request(via(State))] +struct AppState { + inner: InnerState, +} + +#[derive(Clone, Default, FromRequestParts)] +#[from_request(via(State), state(AppState))] +struct InnerState {} + +impl FromRef for InnerState { + fn from_ref(input: &AppState) -> Self { + input.inner.clone() + } +} diff --git a/axum-macros/tests/from_request/pass/state_with_rejection.rs b/axum-macros/tests/from_request/pass/state_with_rejection.rs new file mode 100644 index 00000000..82ecfe3b --- /dev/null +++ b/axum-macros/tests/from_request/pass/state_with_rejection.rs @@ -0,0 +1,36 @@ +use std::convert::Infallible; +use axum::{ + extract::State, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use axum_macros::FromRequest; + +fn main() { + let _: Router = + Router::with_state(AppState::default()).route("/a", get(|_: Extractor| async {})); +} + +#[derive(Clone, Default, FromRequest)] +#[from_request(rejection(MyRejection))] +struct Extractor { + state: State, +} + +#[derive(Clone, Default)] +struct AppState {} + +struct MyRejection {} + +impl From for MyRejection { + fn from(err: Infallible) -> Self { + match err {} + } +} + +impl IntoResponse for MyRejection { + fn into_response(self) -> Response { + ().into_response() + } +}