From 7705ef666131a11ae77ed8998ea9bb78bfb2f62b Mon Sep 17 00:00:00 2001 From: Jonas Platte <jplatte+git@posteo.de> Date: Tue, 23 Aug 2022 21:14:02 +0200 Subject: [PATCH] Add `#[derive(FromRequestParts)]` (#1305) * Add missing leading double colon * Separate handling of last element in FromRequest derive * FromRequestParts derive * fix it and add lots of tests * docs * changelog * Update axum-macros/src/lib.rs Co-authored-by: Jonas Platte <jplatte+git@posteo.de> Co-authored-by: David Pedersen <david.pdrsn@gmail.com> --- axum-macros/CHANGELOG.md | 3 + axum-macros/src/from_request.rs | 565 ++++++++++++------ axum-macros/src/lib.rs | 53 +- .../fail/parts_extracting_body.rs | 15 + .../fail/parts_extracting_body.stderr | 16 + .../from_request/pass/container_parts.rs | 21 + .../from_request/pass/empty_named_parts.rs | 12 + .../from_request/pass/empty_tuple_parts.rs | 12 + .../tests/from_request/pass/enum_via_parts.rs | 12 + .../tests/from_request/pass/named_parts.rs | 23 + .../from_request/pass/named_via_parts.rs | 34 ++ .../pass/override_rejection_non_generic.rs | 39 ++ .../override_rejection_non_generic_parts.rs | 39 ++ .../pass/override_rejection_parts.rs | 61 ++ ...erride_rejection_with_via_on_enum_parts.rs | 33 + ...ride_rejection_with_via_on_struct_parts.rs | 40 ++ .../tests/from_request/pass/tuple_parts.rs | 12 + .../pass/tuple_same_type_twice_parts.rs | 20 + .../pass/tuple_same_type_twice_via_parts.rs | 21 + .../from_request/pass/tuple_via_parts.rs | 16 + .../tests/from_request/pass/unit_parts.rs | 12 + 21 files changed, 887 insertions(+), 172 deletions(-) create mode 100644 axum-macros/tests/from_request/fail/parts_extracting_body.rs create mode 100644 axum-macros/tests/from_request/fail/parts_extracting_body.stderr create mode 100644 axum-macros/tests/from_request/pass/container_parts.rs create mode 100644 axum-macros/tests/from_request/pass/empty_named_parts.rs create mode 100644 axum-macros/tests/from_request/pass/empty_tuple_parts.rs create mode 100644 axum-macros/tests/from_request/pass/enum_via_parts.rs create mode 100644 axum-macros/tests/from_request/pass/named_parts.rs create mode 100644 axum-macros/tests/from_request/pass/named_via_parts.rs create mode 100644 axum-macros/tests/from_request/pass/override_rejection_non_generic.rs create mode 100644 axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs create mode 100644 axum-macros/tests/from_request/pass/override_rejection_parts.rs create mode 100644 axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum_parts.rs create mode 100644 axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs create mode 100644 axum-macros/tests/from_request/pass/tuple_parts.rs create mode 100644 axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs create mode 100644 axum-macros/tests/from_request/pass/tuple_same_type_twice_via_parts.rs create mode 100644 axum-macros/tests/from_request/pass/tuple_via_parts.rs create mode 100644 axum-macros/tests/from_request/pass/unit_parts.rs diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 619c5bd7..d9112308 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -14,10 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 enum but instead generate `type Rejection = axum::response::Response`. Use the new `#[from_request(rejection(MyRejection))]` attribute to change this. The `rejection_derive` attribute has also been removed ([#1272]) +- **added:** Add `#[derive(FromRequestParts)]` for deriving an implementation of + `FromRequestParts`, similarly to `#[derive(FromRequest)]` ([#1305]) [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1256]: https://github.com/tokio-rs/axum/pull/1256 [#1272]: https://github.com/tokio-rs/axum/pull/1272 +[#1305]: https://github.com/tokio-rs/axum/pull/1305 # 0.2.3 (27. June, 2022) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index f92eaeb4..e1c026a8 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -3,11 +3,27 @@ use self::attr::{ }; use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; +use std::fmt; use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token}; mod attr; -pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> { +#[derive(Clone, Copy)] +pub(crate) enum Trait { + FromRequest, + FromRequestParts, +} + +impl fmt::Display for Trait { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Trait::FromRequest => f.write_str("FromRequest"), + Trait::FromRequestParts => f.write_str("FromRequestParts"), + } + } +} + +pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> { match item { syn::Item::Struct(item) => { let syn::ItemStruct { @@ -20,7 +36,7 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> { struct_token: _, } = item; - let generic_ident = parse_single_generic_type_on_struct(generics, &fields)?; + let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?; match parse_container_attrs(&attrs)? { FromRequestContainerAttr::Via { path, rejection } => { @@ -30,17 +46,18 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> { path, rejection, generic_ident, + tr, ) } FromRequestContainerAttr::Rejection(rejection) => { - error_on_generic_ident(generic_ident)?; + error_on_generic_ident(generic_ident, tr)?; - impl_struct_by_extracting_each_field(ident, fields, Some(rejection)) + impl_struct_by_extracting_each_field(ident, fields, Some(rejection), tr) } FromRequestContainerAttr::None => { - error_on_generic_ident(generic_ident)?; + error_on_generic_ident(generic_ident, tr)?; - impl_struct_by_extracting_each_field(ident, fields, None) + impl_struct_by_extracting_each_field(ident, fields, None, tr) } } } @@ -55,19 +72,19 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> { variants, } = item; - const GENERICS_ERROR: &str = "`#[derive(FromRequest)] on enums don't support generics"; + let generics_error = format!("`#[derive({tr})] on enums don't support generics"); if !generics.params.is_empty() { - return Err(syn::Error::new_spanned(generics, GENERICS_ERROR)); + return Err(syn::Error::new_spanned(generics, generics_error)); } if let Some(where_clause) = generics.where_clause { - return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR)); + return Err(syn::Error::new_spanned(where_clause, generics_error)); } match parse_container_attrs(&attrs)? { FromRequestContainerAttr::Via { path, rejection } => { - impl_enum_by_extracting_all_at_once(ident, variants, path, rejection) + impl_enum_by_extracting_all_at_once(ident, variants, path, rejection, tr) } FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned( rejection, @@ -86,11 +103,12 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> { fn parse_single_generic_type_on_struct( generics: syn::Generics, fields: &syn::Fields, + tr: Trait, ) -> syn::Result<Option<Ident>> { if let Some(where_clause) = generics.where_clause { return Err(syn::Error::new_spanned( where_clause, - "#[derive(FromRequest)] doesn't support structs with `where` clauses", + format_args!("#[derive({tr})] doesn't support structs with `where` clauses"), )); } @@ -103,13 +121,19 @@ fn parse_single_generic_type_on_struct( syn::GenericParam::Lifetime(lifetime) => { return Err(syn::Error::new_spanned( lifetime, - "#[derive(FromRequest)] doesn't support structs that are generic over lifetimes", + format_args!( + "#[derive({tr})] doesn't support structs \ + that are generic over lifetimes" + ), )); } syn::GenericParam::Const(konst) => { return Err(syn::Error::new_spanned( konst, - "#[derive(FromRequest)] doesn't support structs that have const generics", + format_args!( + "#[derive({tr})] doesn't support structs \ + that have const generics" + ), )); } }; @@ -118,14 +142,20 @@ fn parse_single_generic_type_on_struct( syn::Fields::Named(fields_named) => { return Err(syn::Error::new_spanned( fields_named, - "#[derive(FromRequest)] doesn't support named fields for generic structs. Use a tuple struct instead", + format_args!( + "#[derive({tr})] doesn't support named fields \ + for generic structs. Use a tuple struct instead" + ), )); } syn::Fields::Unnamed(fields_unnamed) => { if fields_unnamed.unnamed.len() != 1 { return Err(syn::Error::new_spanned( fields_unnamed, - "#[derive(FromRequest)] only supports generics on tuple structs that have exactly one field", + format_args!( + "#[derive({tr})] only supports generics on \ + tuple structs that have exactly one field" + ), )); } @@ -139,7 +169,10 @@ fn parse_single_generic_type_on_struct( { return Err(syn::Error::new_spanned( type_path, - "#[derive(FromRequest)] only supports generics on tuple structs that have exactly one field of the generic type", + format_args!( + "#[derive({tr})] only supports generics on \ + tuple structs that have exactly one field of the generic type" + ), )); } } else { @@ -153,16 +186,18 @@ fn parse_single_generic_type_on_struct( } _ => Err(syn::Error::new_spanned( generics, - "#[derive(FromRequest)] only supports 0 or 1 generic type parameters", + format_args!("#[derive({tr})] only supports 0 or 1 generic type parameters"), )), } } -fn error_on_generic_ident(generic_ident: Option<Ident>) -> syn::Result<()> { +fn error_on_generic_ident(generic_ident: Option<Ident>, tr: Trait) -> syn::Result<()> { if let Some(generic_ident) = generic_ident { Err(syn::Error::new_spanned( generic_ident, - "#[derive(FromRequest)] only supports generics when used with #[from_request(via)]", + format_args!( + "#[derive({tr})] only supports generics when used with #[from_request(via)]" + ), )) } else { Ok(()) @@ -173,8 +208,9 @@ fn impl_struct_by_extracting_each_field( ident: syn::Ident, fields: syn::Fields, rejection: Option<syn::Path>, + tr: Trait, ) -> syn::Result<TokenStream> { - let extract_fields = extract_fields(&fields, &rejection)?; + let extract_fields = extract_fields(&fields, &rejection, tr)?; let rejection_ident = if let Some(rejection) = rejection { quote!(#rejection) @@ -184,27 +220,48 @@ fn impl_struct_by_extracting_each_field( quote!(::axum::response::Response) }; - Ok(quote! { - #[::axum::async_trait] - #[automatically_derived] - impl<S, B> ::axum::extract::FromRequest<S, B> for #ident - where - B: ::axum::body::HttpBody + ::std::marker::Send + 'static, - B::Data: ::std::marker::Send, - B::Error: ::std::convert::Into<::axum::BoxError>, - S: ::std::marker::Send + ::std::marker::Sync, - { - type Rejection = #rejection_ident; + Ok(match tr { + Trait::FromRequest => quote! { + #[::axum::async_trait] + #[automatically_derived] + impl<S, B> ::axum::extract::FromRequest<S, B> for #ident + where + B: ::axum::body::HttpBody + ::std::marker::Send + 'static, + B::Data: ::std::marker::Send, + B::Error: ::std::convert::Into<::axum::BoxError>, + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #rejection_ident; - async fn from_request( - mut req: axum::http::Request<B>, - state: &S, - ) -> ::std::result::Result<Self, Self::Rejection> { - ::std::result::Result::Ok(Self { - #(#extract_fields)* - }) + async fn from_request( + mut req: ::axum::http::Request<B>, + state: &S, + ) -> ::std::result::Result<Self, Self::Rejection> { + ::std::result::Result::Ok(Self { + #(#extract_fields)* + }) + } } - } + }, + Trait::FromRequestParts => quote! { + #[::axum::async_trait] + #[automatically_derived] + impl<S> ::axum::extract::FromRequestParts<S> for #ident + where + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #rejection_ident; + + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + state: &S, + ) -> ::std::result::Result<Self, Self::Rejection> { + ::std::result::Result::Ok(Self { + #(#extract_fields)* + }) + } + } + }, }) } @@ -219,82 +276,118 @@ fn has_no_fields(fields: &syn::Fields) -> bool { fn extract_fields( fields: &syn::Fields, rejection: &Option<syn::Path>, + tr: Trait, ) -> syn::Result<Vec<TokenStream>> { - fields - .iter() - .enumerate() - .map(|(index, field)| { - let is_last = fields.len() - 1 == index; - - let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; - - let member = if let Some(ident) = &field.ident { - quote! { #ident } - } else { + fn member(field: &syn::Field, index: usize) -> TokenStream { + match &field.ident { + Some(ident) => quote! { #ident }, + _ => { let member = syn::Member::Unnamed(syn::Index { index: index as u32, span: field.span(), }); quote! { #member } - }; + } + } + } + fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { + if let Some((_, path)) = via { + let span = path.span(); + quote_spanned! {span=> + |#path(inner)| inner + } + } else { + quote_spanned! {ty_span=> + ::std::convert::identity + } + } + } + + let mut fields_iter = fields.iter(); + + let last = match tr { + // Use FromRequestParts for all elements except the last + Trait::FromRequest => fields_iter.next_back(), + // Use FromRequestParts for all elements + Trait::FromRequestParts => None, + }; + + let mut res: Vec<_> = fields_iter + .enumerate() + .map(|(index, field)| { + let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; + + let member = member(field, index); let ty_span = field.ty.span(); - - let into_inner = if let Some((_, path)) = via { - let span = path.span(); - quote_spanned! {span=> - |#path(inner)| inner - } - } else { - quote_spanned! {ty_span=> - ::std::convert::identity - } - }; + let into_inner = into_inner(via, ty_span); if peel_option(&field.ty).is_some() { - if is_last { - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req, state) + let tokens = match tr { + Trait::FromRequest => { + quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = + ::axum::extract::FromRequestParts::from_request_parts( + &mut parts, + state, + ) + .await + .ok() + .map(#into_inner); + req = ::axum::http::Request::from_parts(parts, body); + value + }, + } + } + Trait::FromRequestParts => { + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequestParts::from_request_parts( + parts, + state, + ) .await .ok() .map(#into_inner) - }, - }) - } else { - Ok(quote_spanned! {ty_span=> - #member: { - let (mut parts, body) = req.into_parts(); - let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) - .await - .ok() - .map(#into_inner); - req = ::axum::http::Request::from_parts(parts, body); - value - }, - }) - } + }, + } + } + }; + Ok(tokens) } else if peel_result_ok(&field.ty).is_some() { - if is_last { - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req, state) + let tokens = match tr { + Trait::FromRequest => { + quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = + ::axum::extract::FromRequestParts::from_request_parts( + &mut parts, + state, + ) + .await + .map(#into_inner); + req = ::axum::http::Request::from_parts(parts, body); + value + }, + } + } + Trait::FromRequestParts => { + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequestParts::from_request_parts( + parts, + state, + ) .await .map(#into_inner) - }, - }) - } else { - Ok(quote_spanned! {ty_span=> - #member: { - let (mut parts, body) = req.into_parts(); - let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) - .await - .map(#into_inner); - req = ::axum::http::Request::from_parts(parts, body); - value - }, - }) - } + }, + } + } + }; + Ok(tokens) } else { let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } @@ -302,31 +395,89 @@ fn extract_fields( quote! { ::axum::response::IntoResponse::into_response } }; - if is_last { - Ok(quote_spanned! {ty_span=> - #member: { - ::axum::extract::FromRequest::from_request(req, state) + let tokens = match tr { + Trait::FromRequest => { + quote_spanned! {ty_span=> + #member: { + let (mut parts, body) = req.into_parts(); + let value = + ::axum::extract::FromRequestParts::from_request_parts( + &mut parts, + state, + ) + .await + .map(#into_inner) + .map_err(#map_err)?; + req = ::axum::http::Request::from_parts(parts, body); + value + }, + } + } + Trait::FromRequestParts => { + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequestParts::from_request_parts( + parts, + state, + ) .await .map(#into_inner) .map_err(#map_err)? - }, - }) - } else { - Ok(quote_spanned! {ty_span=> - #member: { - let (mut parts, body) = req.into_parts(); - let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state) - .await - .map(#into_inner) - .map_err(#map_err)?; - req = ::axum::http::Request::from_parts(parts, body); - value - }, - }) - } + }, + } + } + }; + Ok(tokens) } }) - .collect() + .collect::<syn::Result<_>>()?; + + // Handle the last element, if deriving FromRequest + if let Some(field) = last { + let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?; + + let member = member(field, fields.len() - 1); + let ty_span = field.ty.span(); + let into_inner = into_inner(via, ty_span); + + let item = if peel_option(&field.ty).is_some() { + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .ok() + .map(#into_inner) + }, + } + } else if peel_result_ok(&field.ty).is_some() { + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(#into_inner) + }, + } + } else { + let map_err = if let Some(rejection) = rejection { + quote! { <#rejection as ::std::convert::From<_>>::from } + } else { + quote! { ::axum::response::IntoResponse::into_response } + }; + + quote_spanned! {ty_span=> + #member: { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(#into_inner) + .map_err(#map_err)? + }, + } + }; + + res.push(item); + } + + Ok(res) } fn peel_option(ty: &syn::Type) -> Option<&syn::Type> { @@ -397,6 +548,7 @@ fn impl_struct_by_extracting_all_at_once( path: syn::Path, rejection: Option<syn::Path>, generic_ident: Option<Ident>, + tr: Trait, ) -> syn::Result<TokenStream> { let fields = match fields { syn::Fields::Named(fields) => fields.named.into_iter(), @@ -430,21 +582,35 @@ fn impl_struct_by_extracting_all_at_once( }; let rejection_bound = rejection.as_ref().map(|rejection| { - if generic_ident.is_some() { - quote! { - #rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>, - } - } else { - quote! { - #rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>, + match (tr, generic_ident.is_some()) { + (Trait::FromRequest, true) => { + quote! { + #rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>, + } + }, + (Trait::FromRequest, false) => { + quote! { + #rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>, + } + }, + (Trait::FromRequestParts, true) => { + quote! { + #rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequestParts<S>>::Rejection>, + } + }, + (Trait::FromRequestParts, false) => { + quote! { + #rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection>, + } } } }).unwrap_or_default(); - let impl_generics = if generic_ident.is_some() { - quote! { S, B, T } - } else { - quote! { S, B } + let impl_generics = match (tr, generic_ident.is_some()) { + (Trait::FromRequest, true) => quote! { S, B, T }, + (Trait::FromRequest, false) => quote! { S, B }, + (Trait::FromRequestParts, true) => quote! { S, T }, + (Trait::FromRequestParts, false) => quote! { S }, }; let type_generics = generic_ident @@ -466,29 +632,59 @@ fn impl_struct_by_extracting_all_at_once( quote! { value } }; - Ok(quote_spanned! {path_span=> - #[::axum::async_trait] - #[automatically_derived] - impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics - where - #path<#via_type_generics>: ::axum::extract::FromRequest<S, B>, - #rejection_bound - B: ::std::marker::Send + 'static, - S: ::std::marker::Send + ::std::marker::Sync, - { - type Rejection = #associated_rejection_type; + let tokens = match tr { + Trait::FromRequest => { + quote_spanned! {path_span=> + #[::axum::async_trait] + #[automatically_derived] + impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics + where + #path<#via_type_generics>: ::axum::extract::FromRequest<S, B>, + #rejection_bound + B: ::std::marker::Send + 'static, + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #associated_rejection_type; - async fn from_request( - req: ::axum::http::Request<B>, - state: &S - ) -> ::std::result::Result<Self, Self::Rejection> { - ::axum::extract::FromRequest::from_request(req, state) - .await - .map(|#path(value)| #value_to_self) - .map_err(#map_err) + async fn from_request( + req: ::axum::http::Request<B>, + state: &S + ) -> ::std::result::Result<Self, Self::Rejection> { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(|#path(value)| #value_to_self) + .map_err(#map_err) + } + } } } - }) + Trait::FromRequestParts => { + quote_spanned! {path_span=> + #[::axum::async_trait] + #[automatically_derived] + impl<#impl_generics> ::axum::extract::FromRequestParts<S> for #ident #type_generics + where + #path<#via_type_generics>: ::axum::extract::FromRequestParts<S>, + #rejection_bound + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #associated_rejection_type; + + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + state: &S + ) -> ::std::result::Result<Self, Self::Rejection> { + ::axum::extract::FromRequestParts::from_request_parts(parts, state) + .await + .map(|#path(value)| #value_to_self) + .map_err(#map_err) + } + } + } + } + }; + + Ok(tokens) } fn impl_enum_by_extracting_all_at_once( @@ -496,6 +692,7 @@ fn impl_enum_by_extracting_all_at_once( variants: Punctuated<syn::Variant, Token![,]>, path: syn::Path, rejection: Option<syn::Path>, + tr: Trait, ) -> syn::Result<TokenStream> { for variant in variants { let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?; @@ -537,29 +734,57 @@ fn impl_enum_by_extracting_all_at_once( let path_span = path.span(); - Ok(quote_spanned! {path_span=> - #[::axum::async_trait] - #[automatically_derived] - impl<S, B> ::axum::extract::FromRequest<S, B> for #ident - where - B: ::axum::body::HttpBody + ::std::marker::Send + 'static, - B::Data: ::std::marker::Send, - B::Error: ::std::convert::Into<::axum::BoxError>, - S: ::std::marker::Send + ::std::marker::Sync, - { - type Rejection = #associated_rejection_type; + let tokens = match tr { + Trait::FromRequest => { + quote_spanned! {path_span=> + #[::axum::async_trait] + #[automatically_derived] + impl<S, B> ::axum::extract::FromRequest<S, B> for #ident + where + B: ::axum::body::HttpBody + ::std::marker::Send + 'static, + B::Data: ::std::marker::Send, + B::Error: ::std::convert::Into<::axum::BoxError>, + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #associated_rejection_type; - async fn from_request( - req: ::axum::http::Request<B>, - state: &S - ) -> ::std::result::Result<Self, Self::Rejection> { - ::axum::extract::FromRequest::from_request(req, state) - .await - .map(|#path(inner)| inner) - .map_err(#map_err) + async fn from_request( + req: ::axum::http::Request<B>, + state: &S + ) -> ::std::result::Result<Self, Self::Rejection> { + ::axum::extract::FromRequest::from_request(req, state) + .await + .map(|#path(inner)| inner) + .map_err(#map_err) + } + } } } - }) + Trait::FromRequestParts => { + quote_spanned! {path_span=> + #[::axum::async_trait] + #[automatically_derived] + impl<S> ::axum::extract::FromRequestParts<S> for #ident + where + S: ::std::marker::Send + ::std::marker::Sync, + { + type Rejection = #associated_rejection_type; + + async fn from_request_parts( + parts: &mut ::axum::http::request::Parts, + state: &S + ) -> ::std::result::Result<Self, Self::Rejection> { + ::axum::extract::FromRequestParts::from_request_parts(parts, state) + .await + .map(|#path(inner)| inner) + .map_err(#map_err) + } + } + } + } + }; + + Ok(tokens) } #[test] diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index a3ebf2a5..63f17b0b 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -52,6 +52,8 @@ mod from_request; mod typed_path; mod with_position; +use from_request::Trait::{FromRequest, FromRequestParts}; + /// Derive an implementation of [`FromRequest`]. /// /// Supports generating two kinds of implementations: @@ -87,6 +89,8 @@ mod with_position; /// /// This requires that each field is an extractor (i.e. implements [`FromRequest`]). /// +/// Note that only the last field can consume the request body. Therefore this doesn't compile: +/// /// ```compile_fail /// use axum_macros::FromRequest; /// use axum::body::Bytes; @@ -99,7 +103,6 @@ mod with_position; /// string: String, /// } /// ``` -/// Note that only the last field can consume the request body. Therefore this doesn't compile: /// /// ## Extracting via another extractor /// @@ -353,7 +356,53 @@ mod with_position; /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html #[proc_macro_derive(FromRequest, attributes(from_request))] pub fn derive_from_request(item: TokenStream) -> TokenStream { - expand_with(item, from_request::expand) + expand_with(item, |item| from_request::expand(item, FromRequest)) +} + +/// Derive an implementation of [`FromRequestParts`]. +/// +/// This works similarly to `#[derive(FromRequest)]` except it uses [`FromRequestParts`]. All the +/// same options are supported. +/// +/// # Example +/// +/// ``` +/// use axum_macros::FromRequestParts; +/// use axum::{ +/// extract::{Query, TypedHeader}, +/// headers::ContentType, +/// }; +/// use std::collections::HashMap; +/// +/// #[derive(FromRequestParts)] +/// struct MyExtractor { +/// #[from_request(via(Query))] +/// query_params: HashMap<String, String>, +/// content_type: TypedHeader<ContentType>, +/// } +/// +/// async fn handler(extractor: MyExtractor) {} +/// ``` +/// +/// # Cannot extract the body +/// +/// [`FromRequestParts`] cannot extract the request body: +/// +/// ```compile_fail +/// use axum_macros::FromRequestParts; +/// +/// #[derive(FromRequestParts)] +/// struct MyExtractor { +/// body: String, +/// } +/// ``` +/// +/// Use `#[derive(FromRequest)]` for that. +/// +/// [`FromRequestParts`]: https://docs.rs/axum/0.6/axum/extract/trait.FromRequestParts.html +#[proc_macro_derive(FromRequestParts, attributes(from_request))] +pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { + expand_with(item, |item| from_request::expand(item, FromRequestParts)) } /// Generates better error messages when applied handler functions. diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.rs b/axum-macros/tests/from_request/fail/parts_extracting_body.rs new file mode 100644 index 00000000..18fb312d --- /dev/null +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.rs @@ -0,0 +1,15 @@ +use axum::{extract::FromRequestParts, response::Response}; +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor { + body: String, +} + +fn assert_from_request() +where + Extractor: FromRequestParts<(), Rejection = Response>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr new file mode 100644 index 00000000..2a2c8040 --- /dev/null +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -0,0 +1,16 @@ +error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied + --> tests/from_request/fail/parts_extracting_body.rs:6:11 + | +6 | body: String, + | ^^^^^^ the trait `FromRequestParts<S>` is not implemented for `String` + | + = help: the following other types implement trait `FromRequestParts<S>`: + <() as FromRequestParts<S>> + <(T1, T2) as FromRequestParts<S>> + <(T1, T2, T3) as FromRequestParts<S>> + <(T1, T2, T3, T4) as FromRequestParts<S>> + <(T1, T2, T3, T4, T5) as FromRequestParts<S>> + <(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>> + and 27 others diff --git a/axum-macros/tests/from_request/pass/container_parts.rs b/axum-macros/tests/from_request/pass/container_parts.rs new file mode 100644 index 00000000..c3dabe54 --- /dev/null +++ b/axum-macros/tests/from_request/pass/container_parts.rs @@ -0,0 +1,21 @@ +use axum::{ + extract::{FromRequestParts, Extension}, + response::Response, +}; +use axum_macros::FromRequestParts; + +#[derive(Clone, FromRequestParts)] +#[from_request(via(Extension))] +struct Extractor { + one: i32, + two: String, + three: bool, +} + +fn assert_from_request() +where + Extractor: FromRequestParts<(), Rejection = Response>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/empty_named_parts.rs b/axum-macros/tests/from_request/pass/empty_named_parts.rs new file mode 100644 index 00000000..20194b68 --- /dev/null +++ b/axum-macros/tests/from_request/pass/empty_named_parts.rs @@ -0,0 +1,12 @@ +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/empty_tuple_parts.rs b/axum-macros/tests/from_request/pass/empty_tuple_parts.rs new file mode 100644 index 00000000..ade2125b --- /dev/null +++ b/axum-macros/tests/from_request/pass/empty_tuple_parts.rs @@ -0,0 +1,12 @@ +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor(); + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/enum_via_parts.rs b/axum-macros/tests/from_request/pass/enum_via_parts.rs new file mode 100644 index 00000000..5e18d922 --- /dev/null +++ b/axum-macros/tests/from_request/pass/enum_via_parts.rs @@ -0,0 +1,12 @@ +use axum::{body::Body, routing::get, Extension, Router}; +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts, Clone)] +#[from_request(via(Extension))] +enum Extractor {} + +async fn foo(_: Extractor) {} + +fn main() { + Router::<(), Body>::new().route("/", get(foo)); +} diff --git a/axum-macros/tests/from_request/pass/named_parts.rs b/axum-macros/tests/from_request/pass/named_parts.rs new file mode 100644 index 00000000..b997a0dd --- /dev/null +++ b/axum-macros/tests/from_request/pass/named_parts.rs @@ -0,0 +1,23 @@ +use axum::{ + extract::{rejection::TypedHeaderRejection, FromRequestParts, TypedHeader}, + headers::{self, UserAgent}, + response::Response, +}; +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor { + uri: axum::http::Uri, + user_agent: TypedHeader<UserAgent>, + content_type: TypedHeader<headers::ContentType>, + etag: Option<TypedHeader<headers::ETag>>, + host: Result<TypedHeader<headers::Host>, TypedHeaderRejection>, +} + +fn assert_from_request() +where + Extractor: FromRequestParts<(), Rejection = Response>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/named_via_parts.rs b/axum-macros/tests/from_request/pass/named_via_parts.rs new file mode 100644 index 00000000..bdf1ac6e --- /dev/null +++ b/axum-macros/tests/from_request/pass/named_via_parts.rs @@ -0,0 +1,34 @@ +use axum::{ + response::Response, + extract::{ + rejection::TypedHeaderRejection, + Extension, FromRequestParts, TypedHeader, + }, + headers::{self, UserAgent}, +}; +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor { + #[from_request(via(Extension))] + state: State, + #[from_request(via(TypedHeader))] + user_agent: UserAgent, + #[from_request(via(TypedHeader))] + content_type: headers::ContentType, + #[from_request(via(TypedHeader))] + etag: Option<headers::ETag>, + #[from_request(via(TypedHeader))] + host: Result<headers::Host, TypedHeaderRejection>, +} + +fn assert_from_request() +where + Extractor: FromRequestParts<(), Rejection = Response>, +{ +} + +#[derive(Clone)] +struct State; + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs b/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs new file mode 100644 index 00000000..6c4d87fe --- /dev/null +++ b/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs @@ -0,0 +1,39 @@ +use axum::{ + extract::rejection::JsonRejection, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use axum_macros::FromRequest; +use std::collections::HashMap; +use serde::Deserialize; + +fn main() { + let _: Router = Router::new().route("/", get(handler).post(handler_result)); +} + +async fn handler(_: MyJson) {} + +async fn handler_result(_: Result<MyJson, MyJsonRejection>) {} + +#[derive(FromRequest, Deserialize)] +#[from_request( + via(axum::extract::Json), + rejection(MyJsonRejection), +)] +#[serde(transparent)] +struct MyJson(HashMap<String, String>); + +struct MyJsonRejection {} + +impl From<JsonRejection> for MyJsonRejection { + fn from(_: JsonRejection) -> Self { + todo!() + } +} + +impl IntoResponse for MyJsonRejection { + fn into_response(self) -> Response { + todo!() + } +} diff --git a/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs new file mode 100644 index 00000000..9aca7345 --- /dev/null +++ b/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs @@ -0,0 +1,39 @@ +use axum::{ + extract::rejection::QueryRejection, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use axum_macros::FromRequestParts; +use std::collections::HashMap; +use serde::Deserialize; + +fn main() { + let _: Router = Router::new().route("/", get(handler).post(handler_result)); +} + +async fn handler(_: MyQuery) {} + +async fn handler_result(_: Result<MyQuery, MyQueryRejection>) {} + +#[derive(FromRequestParts, Deserialize)] +#[from_request( + via(axum::extract::Query), + rejection(MyQueryRejection), +)] +#[serde(transparent)] +struct MyQuery(HashMap<String, String>); + +struct MyQueryRejection {} + +impl From<QueryRejection> for MyQueryRejection { + fn from(_: QueryRejection) -> Self { + todo!() + } +} + +impl IntoResponse for MyQueryRejection { + fn into_response(self) -> Response { + todo!() + } +} diff --git a/axum-macros/tests/from_request/pass/override_rejection_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_parts.rs new file mode 100644 index 00000000..8e2271c3 --- /dev/null +++ b/axum-macros/tests/from_request/pass/override_rejection_parts.rs @@ -0,0 +1,61 @@ +use axum::{ + async_trait, + extract::{rejection::ExtensionRejection, FromRequestParts}, + http::{request::Parts, StatusCode}, + response::{IntoResponse, Response}, + routing::get, + Extension, Router, +}; +use axum_macros::FromRequestParts; + +fn main() { + let _: Router = Router::new().route("/", get(handler).post(handler_result)); +} + +async fn handler(_: MyExtractor) {} + +async fn handler_result(_: Result<MyExtractor, MyRejection>) {} + +#[derive(FromRequestParts)] +#[from_request(rejection(MyRejection))] +struct MyExtractor { + one: Extension<String>, + #[from_request(via(Extension))] + two: String, + three: OtherExtractor, +} + +struct OtherExtractor; + +#[async_trait] +impl<S> FromRequestParts<S> for OtherExtractor +where + S: Send + Sync, +{ + // this rejection doesn't implement `Display` and `Error` + type Rejection = (StatusCode, String); + + async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + todo!() + } +} + +struct MyRejection {} + +impl From<ExtensionRejection> for MyRejection { + fn from(_: ExtensionRejection) -> Self { + todo!() + } +} + +impl From<(StatusCode, String)> for MyRejection { + fn from(_: (StatusCode, String)) -> Self { + todo!() + } +} + +impl IntoResponse for MyRejection { + fn into_response(self) -> Response { + todo!() + } +} diff --git a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum_parts.rs new file mode 100644 index 00000000..bd1fc062 --- /dev/null +++ b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_enum_parts.rs @@ -0,0 +1,33 @@ +use axum::{ + extract::rejection::ExtensionRejection, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use axum_macros::FromRequestParts; + +fn main() { + let _: Router = Router::new().route("/", get(handler).post(handler_result)); +} + +async fn handler(_: MyExtractor) {} + +async fn handler_result(_: Result<MyExtractor, MyRejection>) {} + +#[derive(FromRequestParts, Clone)] +#[from_request(via(axum::Extension), rejection(MyRejection))] +enum MyExtractor {} + +struct MyRejection {} + +impl From<ExtensionRejection> for MyRejection { + fn from(_: ExtensionRejection) -> Self { + todo!() + } +} + +impl IntoResponse for MyRejection { + fn into_response(self) -> Response { + todo!() + } +} diff --git a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs new file mode 100644 index 00000000..eaeeeacf --- /dev/null +++ b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs @@ -0,0 +1,40 @@ +use axum::{ + extract::rejection::QueryRejection, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use axum_macros::FromRequestParts; +use serde::Deserialize; + +fn main() { + let _: Router = Router::new().route("/", get(handler).post(handler_result)); +} + +#[derive(Deserialize)] +struct Payload {} + +async fn handler(_: MyQuery<Payload>) {} + +async fn handler_result(_: Result<MyQuery<Payload>, MyQueryRejection>) {} + +#[derive(FromRequestParts)] +#[from_request( + via(axum::extract::Query), + rejection(MyQueryRejection), +)] +struct MyQuery<T>(T); + +struct MyQueryRejection {} + +impl From<QueryRejection> for MyQueryRejection { + fn from(_: QueryRejection) -> Self { + todo!() + } +} + +impl IntoResponse for MyQueryRejection { + fn into_response(self) -> Response { + todo!() + } +} diff --git a/axum-macros/tests/from_request/pass/tuple_parts.rs b/axum-macros/tests/from_request/pass/tuple_parts.rs new file mode 100644 index 00000000..80a9458b --- /dev/null +++ b/axum-macros/tests/from_request/pass/tuple_parts.rs @@ -0,0 +1,12 @@ +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor(axum::http::HeaderMap, axum::http::Method); + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<()>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs new file mode 100644 index 00000000..44c42dc5 --- /dev/null +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs @@ -0,0 +1,20 @@ +use axum::extract::Query; +use axum_macros::FromRequestParts; +use serde::Deserialize; + +#[derive(FromRequestParts)] +struct Extractor( + Query<Payload>, + axum::extract::Path<Payload>, +); + +#[derive(Deserialize)] +struct Payload {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<()>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via_parts.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via_parts.rs new file mode 100644 index 00000000..10de8708 --- /dev/null +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via_parts.rs @@ -0,0 +1,21 @@ +use axum::extract::Query; +use axum::response::Response; +use axum_macros::FromRequestParts; +use serde::Deserialize; + +#[derive(FromRequestParts)] +struct Extractor( + #[from_request(via(Query))] Payload, + #[from_request(via(axum::extract::Path))] Payload, +); + +#[derive(Deserialize)] +struct Payload {} + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<(), Rejection = Response>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/tuple_via_parts.rs b/axum-macros/tests/from_request/pass/tuple_via_parts.rs new file mode 100644 index 00000000..fac99911 --- /dev/null +++ b/axum-macros/tests/from_request/pass/tuple_via_parts.rs @@ -0,0 +1,16 @@ +use axum::Extension; +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor(#[from_request(via(Extension))] State); + +#[derive(Clone)] +struct State; + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<()>, +{ +} + +fn main() {} diff --git a/axum-macros/tests/from_request/pass/unit_parts.rs b/axum-macros/tests/from_request/pass/unit_parts.rs new file mode 100644 index 00000000..06c07344 --- /dev/null +++ b/axum-macros/tests/from_request/pass/unit_parts.rs @@ -0,0 +1,12 @@ +use axum_macros::FromRequestParts; + +#[derive(FromRequestParts)] +struct Extractor; + +fn assert_from_request() +where + Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>, +{ +} + +fn main() {}