Add #[derive(FromRequestParts)] (#1305)

* Add missing leading double colon

* Separate handling of last element in FromRequest derive

* FromRequestParts derive

* fix it and add lots of tests

* docs

* changelog

* Update axum-macros/src/lib.rs

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Jonas Platte 2022-08-23 21:14:02 +02:00 committed by GitHub
parent db08419a3b
commit 7705ef6661
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 887 additions and 172 deletions

View file

@ -14,10 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
enum but instead generate `type Rejection = axum::response::Response`. Use the
new `#[from_request(rejection(MyRejection))]` attribute to change this.
The `rejection_derive` attribute has also been removed ([#1272])
- **added:** Add `#[derive(FromRequestParts)]` for deriving an implementation of
`FromRequestParts`, similarly to `#[derive(FromRequest)]` ([#1305])
[#1239]: https://github.com/tokio-rs/axum/pull/1239
[#1256]: https://github.com/tokio-rs/axum/pull/1256
[#1272]: https://github.com/tokio-rs/axum/pull/1272
[#1305]: https://github.com/tokio-rs/axum/pull/1305
# 0.2.3 (27. June, 2022)

View file

@ -3,11 +3,27 @@ use self::attr::{
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use std::fmt;
use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token};
mod attr;
pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
#[derive(Clone, Copy)]
pub(crate) enum Trait {
FromRequest,
FromRequestParts,
}
impl fmt::Display for Trait {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Trait::FromRequest => f.write_str("FromRequest"),
Trait::FromRequestParts => f.write_str("FromRequestParts"),
}
}
}
pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
match item {
syn::Item::Struct(item) => {
let syn::ItemStruct {
@ -20,7 +36,7 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
struct_token: _,
} = item;
let generic_ident = parse_single_generic_type_on_struct(generics, &fields)?;
let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via { path, rejection } => {
@ -30,17 +46,18 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
path,
rejection,
generic_ident,
tr,
)
}
FromRequestContainerAttr::Rejection(rejection) => {
error_on_generic_ident(generic_ident)?;
error_on_generic_ident(generic_ident, tr)?;
impl_struct_by_extracting_each_field(ident, fields, Some(rejection))
impl_struct_by_extracting_each_field(ident, fields, Some(rejection), tr)
}
FromRequestContainerAttr::None => {
error_on_generic_ident(generic_ident)?;
error_on_generic_ident(generic_ident, tr)?;
impl_struct_by_extracting_each_field(ident, fields, None)
impl_struct_by_extracting_each_field(ident, fields, None, tr)
}
}
}
@ -55,19 +72,19 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
variants,
} = item;
const GENERICS_ERROR: &str = "`#[derive(FromRequest)] on enums don't support generics";
let generics_error = format!("`#[derive({tr})] on enums don't support generics");
if !generics.params.is_empty() {
return Err(syn::Error::new_spanned(generics, GENERICS_ERROR));
return Err(syn::Error::new_spanned(generics, generics_error));
}
if let Some(where_clause) = generics.where_clause {
return Err(syn::Error::new_spanned(where_clause, GENERICS_ERROR));
return Err(syn::Error::new_spanned(where_clause, generics_error));
}
match parse_container_attrs(&attrs)? {
FromRequestContainerAttr::Via { path, rejection } => {
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection)
impl_enum_by_extracting_all_at_once(ident, variants, path, rejection, tr)
}
FromRequestContainerAttr::Rejection(rejection) => Err(syn::Error::new_spanned(
rejection,
@ -86,11 +103,12 @@ pub(crate) fn expand(item: syn::Item) -> syn::Result<TokenStream> {
fn parse_single_generic_type_on_struct(
generics: syn::Generics,
fields: &syn::Fields,
tr: Trait,
) -> syn::Result<Option<Ident>> {
if let Some(where_clause) = generics.where_clause {
return Err(syn::Error::new_spanned(
where_clause,
"#[derive(FromRequest)] doesn't support structs with `where` clauses",
format_args!("#[derive({tr})] doesn't support structs with `where` clauses"),
));
}
@ -103,13 +121,19 @@ fn parse_single_generic_type_on_struct(
syn::GenericParam::Lifetime(lifetime) => {
return Err(syn::Error::new_spanned(
lifetime,
"#[derive(FromRequest)] doesn't support structs that are generic over lifetimes",
format_args!(
"#[derive({tr})] doesn't support structs \
that are generic over lifetimes"
),
));
}
syn::GenericParam::Const(konst) => {
return Err(syn::Error::new_spanned(
konst,
"#[derive(FromRequest)] doesn't support structs that have const generics",
format_args!(
"#[derive({tr})] doesn't support structs \
that have const generics"
),
));
}
};
@ -118,14 +142,20 @@ fn parse_single_generic_type_on_struct(
syn::Fields::Named(fields_named) => {
return Err(syn::Error::new_spanned(
fields_named,
"#[derive(FromRequest)] doesn't support named fields for generic structs. Use a tuple struct instead",
format_args!(
"#[derive({tr})] doesn't support named fields \
for generic structs. Use a tuple struct instead"
),
));
}
syn::Fields::Unnamed(fields_unnamed) => {
if fields_unnamed.unnamed.len() != 1 {
return Err(syn::Error::new_spanned(
fields_unnamed,
"#[derive(FromRequest)] only supports generics on tuple structs that have exactly one field",
format_args!(
"#[derive({tr})] only supports generics on \
tuple structs that have exactly one field"
),
));
}
@ -139,7 +169,10 @@ fn parse_single_generic_type_on_struct(
{
return Err(syn::Error::new_spanned(
type_path,
"#[derive(FromRequest)] only supports generics on tuple structs that have exactly one field of the generic type",
format_args!(
"#[derive({tr})] only supports generics on \
tuple structs that have exactly one field of the generic type"
),
));
}
} else {
@ -153,16 +186,18 @@ fn parse_single_generic_type_on_struct(
}
_ => Err(syn::Error::new_spanned(
generics,
"#[derive(FromRequest)] only supports 0 or 1 generic type parameters",
format_args!("#[derive({tr})] only supports 0 or 1 generic type parameters"),
)),
}
}
fn error_on_generic_ident(generic_ident: Option<Ident>) -> syn::Result<()> {
fn error_on_generic_ident(generic_ident: Option<Ident>, tr: Trait) -> syn::Result<()> {
if let Some(generic_ident) = generic_ident {
Err(syn::Error::new_spanned(
generic_ident,
"#[derive(FromRequest)] only supports generics when used with #[from_request(via)]",
format_args!(
"#[derive({tr})] only supports generics when used with #[from_request(via)]"
),
))
} else {
Ok(())
@ -173,8 +208,9 @@ fn impl_struct_by_extracting_each_field(
ident: syn::Ident,
fields: syn::Fields,
rejection: Option<syn::Path>,
tr: Trait,
) -> syn::Result<TokenStream> {
let extract_fields = extract_fields(&fields, &rejection)?;
let extract_fields = extract_fields(&fields, &rejection, tr)?;
let rejection_ident = if let Some(rejection) = rejection {
quote!(#rejection)
@ -184,27 +220,48 @@ fn impl_struct_by_extracting_each_field(
quote!(::axum::response::Response)
};
Ok(quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #rejection_ident;
Ok(match tr {
Trait::FromRequest => quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #rejection_ident;
async fn from_request(
mut req: axum::http::Request<B>,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self {
#(#extract_fields)*
})
async fn from_request(
mut req: ::axum::http::Request<B>,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self {
#(#extract_fields)*
})
}
}
}
},
Trait::FromRequestParts => quote! {
#[::axum::async_trait]
#[automatically_derived]
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #rejection_ident;
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S,
) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self {
#(#extract_fields)*
})
}
}
},
})
}
@ -219,82 +276,118 @@ fn has_no_fields(fields: &syn::Fields) -> bool {
fn extract_fields(
fields: &syn::Fields,
rejection: &Option<syn::Path>,
tr: Trait,
) -> syn::Result<Vec<TokenStream>> {
fields
.iter()
.enumerate()
.map(|(index, field)| {
let is_last = fields.len() - 1 == index;
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = if let Some(ident) = &field.ident {
quote! { #ident }
} else {
fn member(field: &syn::Field, index: usize) -> TokenStream {
match &field.ident {
Some(ident) => quote! { #ident },
_ => {
let member = syn::Member::Unnamed(syn::Index {
index: index as u32,
span: field.span(),
});
quote! { #member }
};
}
}
}
fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream {
if let Some((_, path)) = via {
let span = path.span();
quote_spanned! {span=>
|#path(inner)| inner
}
} else {
quote_spanned! {ty_span=>
::std::convert::identity
}
}
}
let mut fields_iter = fields.iter();
let last = match tr {
// Use FromRequestParts for all elements except the last
Trait::FromRequest => fields_iter.next_back(),
// Use FromRequestParts for all elements
Trait::FromRequestParts => None,
};
let mut res: Vec<_> = fields_iter
.enumerate()
.map(|(index, field)| {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = member(field, index);
let ty_span = field.ty.span();
let into_inner = if let Some((_, path)) = via {
let span = path.span();
quote_spanned! {span=>
|#path(inner)| inner
}
} else {
quote_spanned! {ty_span=>
::std::convert::identity
}
};
let into_inner = into_inner(via, ty_span);
if peel_option(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
let tokens = match tr {
Trait::FromRequest => {
quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value =
::axum::extract::FromRequestParts::from_request_parts(
&mut parts,
state,
)
.await
.ok()
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
}
}
Trait::FromRequestParts => {
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequestParts::from_request_parts(
parts,
state,
)
.await
.ok()
.map(#into_inner)
},
})
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.ok()
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
},
}
}
};
Ok(tokens)
} else if peel_result_ok(&field.ty).is_some() {
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
let tokens = match tr {
Trait::FromRequest => {
quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value =
::axum::extract::FromRequestParts::from_request_parts(
&mut parts,
state,
)
.await
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
}
}
Trait::FromRequestParts => {
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequestParts::from_request_parts(
parts,
state,
)
.await
.map(#into_inner)
},
})
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner);
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
},
}
}
};
Ok(tokens)
} else {
let map_err = if let Some(rejection) = rejection {
quote! { <#rejection as ::std::convert::From<_>>::from }
@ -302,31 +395,89 @@ fn extract_fields(
quote! { ::axum::response::IntoResponse::into_response }
};
if is_last {
Ok(quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
let tokens = match tr {
Trait::FromRequest => {
quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value =
::axum::extract::FromRequestParts::from_request_parts(
&mut parts,
state,
)
.await
.map(#into_inner)
.map_err(#map_err)?;
req = ::axum::http::Request::from_parts(parts, body);
value
},
}
}
Trait::FromRequestParts => {
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequestParts::from_request_parts(
parts,
state,
)
.await
.map(#into_inner)
.map_err(#map_err)?
},
})
} else {
Ok(quote_spanned! {ty_span=>
#member: {
let (mut parts, body) = req.into_parts();
let value = ::axum::extract::FromRequestParts::from_request_parts(&mut parts, state)
.await
.map(#into_inner)
.map_err(#map_err)?;
req = ::axum::http::Request::from_parts(parts, body);
value
},
})
}
},
}
}
};
Ok(tokens)
}
})
.collect()
.collect::<syn::Result<_>>()?;
// Handle the last element, if deriving FromRequest
if let Some(field) = last {
let FromRequestFieldAttr { via } = parse_field_attrs(&field.attrs)?;
let member = member(field, fields.len() - 1);
let ty_span = field.ty.span();
let into_inner = into_inner(via, ty_span);
let item = if peel_option(&field.ty).is_some() {
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
.await
.ok()
.map(#into_inner)
},
}
} else if peel_result_ok(&field.ty).is_some() {
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(#into_inner)
},
}
} else {
let map_err = if let Some(rejection) = rejection {
quote! { <#rejection as ::std::convert::From<_>>::from }
} else {
quote! { ::axum::response::IntoResponse::into_response }
};
quote_spanned! {ty_span=>
#member: {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(#into_inner)
.map_err(#map_err)?
},
}
};
res.push(item);
}
Ok(res)
}
fn peel_option(ty: &syn::Type) -> Option<&syn::Type> {
@ -397,6 +548,7 @@ fn impl_struct_by_extracting_all_at_once(
path: syn::Path,
rejection: Option<syn::Path>,
generic_ident: Option<Ident>,
tr: Trait,
) -> syn::Result<TokenStream> {
let fields = match fields {
syn::Fields::Named(fields) => fields.named.into_iter(),
@ -430,21 +582,35 @@ fn impl_struct_by_extracting_all_at_once(
};
let rejection_bound = rejection.as_ref().map(|rejection| {
if generic_ident.is_some() {
quote! {
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>,
}
} else {
quote! {
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>,
match (tr, generic_ident.is_some()) {
(Trait::FromRequest, true) => {
quote! {
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>,
}
},
(Trait::FromRequest, false) => {
quote! {
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>,
}
},
(Trait::FromRequestParts, true) => {
quote! {
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequestParts<S>>::Rejection>,
}
},
(Trait::FromRequestParts, false) => {
quote! {
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection>,
}
}
}
}).unwrap_or_default();
let impl_generics = if generic_ident.is_some() {
quote! { S, B, T }
} else {
quote! { S, B }
let impl_generics = match (tr, generic_ident.is_some()) {
(Trait::FromRequest, true) => quote! { S, B, T },
(Trait::FromRequest, false) => quote! { S, B },
(Trait::FromRequestParts, true) => quote! { S, T },
(Trait::FromRequestParts, false) => quote! { S },
};
let type_generics = generic_ident
@ -466,29 +632,59 @@ fn impl_struct_by_extracting_all_at_once(
quote! { value }
};
Ok(quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics
where
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound
B: ::std::marker::Send + 'static,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
let tokens = match tr {
Trait::FromRequest => {
quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics
where
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound
B: ::std::marker::Send + 'static,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
async fn from_request(
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(value)| #value_to_self)
.map_err(#map_err)
async fn from_request(
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(value)| #value_to_self)
.map_err(#map_err)
}
}
}
}
})
Trait::FromRequestParts => {
quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<#impl_generics> ::axum::extract::FromRequestParts<S> for #ident #type_generics
where
#path<#via_type_generics>: ::axum::extract::FromRequestParts<S>,
#rejection_bound
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequestParts::from_request_parts(parts, state)
.await
.map(|#path(value)| #value_to_self)
.map_err(#map_err)
}
}
}
}
};
Ok(tokens)
}
fn impl_enum_by_extracting_all_at_once(
@ -496,6 +692,7 @@ fn impl_enum_by_extracting_all_at_once(
variants: Punctuated<syn::Variant, Token![,]>,
path: syn::Path,
rejection: Option<syn::Path>,
tr: Trait,
) -> syn::Result<TokenStream> {
for variant in variants {
let FromRequestFieldAttr { via } = parse_field_attrs(&variant.attrs)?;
@ -537,29 +734,57 @@ fn impl_enum_by_extracting_all_at_once(
let path_span = path.span();
Ok(quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
let tokens = match tr {
Trait::FromRequest => {
quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
async fn from_request(
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(inner)| inner)
.map_err(#map_err)
async fn from_request(
req: ::axum::http::Request<B>,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequest::from_request(req, state)
.await
.map(|#path(inner)| inner)
.map_err(#map_err)
}
}
}
}
})
Trait::FromRequestParts => {
quote_spanned! {path_span=>
#[::axum::async_trait]
#[automatically_derived]
impl<S> ::axum::extract::FromRequestParts<S> for #ident
where
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
async fn from_request_parts(
parts: &mut ::axum::http::request::Parts,
state: &S
) -> ::std::result::Result<Self, Self::Rejection> {
::axum::extract::FromRequestParts::from_request_parts(parts, state)
.await
.map(|#path(inner)| inner)
.map_err(#map_err)
}
}
}
}
};
Ok(tokens)
}
#[test]

View file

@ -52,6 +52,8 @@ mod from_request;
mod typed_path;
mod with_position;
use from_request::Trait::{FromRequest, FromRequestParts};
/// Derive an implementation of [`FromRequest`].
///
/// Supports generating two kinds of implementations:
@ -87,6 +89,8 @@ mod with_position;
///
/// This requires that each field is an extractor (i.e. implements [`FromRequest`]).
///
/// Note that only the last field can consume the request body. Therefore this doesn't compile:
///
/// ```compile_fail
/// use axum_macros::FromRequest;
/// use axum::body::Bytes;
@ -99,7 +103,6 @@ mod with_position;
/// string: String,
/// }
/// ```
/// Note that only the last field can consume the request body. Therefore this doesn't compile:
///
/// ## Extracting via another extractor
///
@ -353,7 +356,53 @@ mod with_position;
/// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html
#[proc_macro_derive(FromRequest, attributes(from_request))]
pub fn derive_from_request(item: TokenStream) -> TokenStream {
expand_with(item, from_request::expand)
expand_with(item, |item| from_request::expand(item, FromRequest))
}
/// Derive an implementation of [`FromRequestParts`].
///
/// This works similarly to `#[derive(FromRequest)]` except it uses [`FromRequestParts`]. All the
/// same options are supported.
///
/// # Example
///
/// ```
/// use axum_macros::FromRequestParts;
/// use axum::{
/// extract::{Query, TypedHeader},
/// headers::ContentType,
/// };
/// use std::collections::HashMap;
///
/// #[derive(FromRequestParts)]
/// struct MyExtractor {
/// #[from_request(via(Query))]
/// query_params: HashMap<String, String>,
/// content_type: TypedHeader<ContentType>,
/// }
///
/// async fn handler(extractor: MyExtractor) {}
/// ```
///
/// # Cannot extract the body
///
/// [`FromRequestParts`] cannot extract the request body:
///
/// ```compile_fail
/// use axum_macros::FromRequestParts;
///
/// #[derive(FromRequestParts)]
/// struct MyExtractor {
/// body: String,
/// }
/// ```
///
/// Use `#[derive(FromRequest)]` for that.
///
/// [`FromRequestParts`]: https://docs.rs/axum/0.6/axum/extract/trait.FromRequestParts.html
#[proc_macro_derive(FromRequestParts, attributes(from_request))]
pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
expand_with(item, |item| from_request::expand(item, FromRequestParts))
}
/// Generates better error messages when applied handler functions.

View file

@ -0,0 +1,15 @@
use axum::{extract::FromRequestParts, response::Response};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {
body: String,
}
fn assert_from_request()
where
Extractor: FromRequestParts<(), Rejection = Response>,
{
}
fn main() {}

View file

@ -0,0 +1,16 @@
error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
--> tests/from_request/fail/parts_extracting_body.rs:6:11
|
6 | body: String,
| ^^^^^^ the trait `FromRequestParts<S>` is not implemented for `String`
|
= help: the following other types implement trait `FromRequestParts<S>`:
<() as FromRequestParts<S>>
<(T1, T2) as FromRequestParts<S>>
<(T1, T2, T3) as FromRequestParts<S>>
<(T1, T2, T3, T4) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts<S>>
<(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts<S>>
and 27 others

View file

@ -0,0 +1,21 @@
use axum::{
extract::{FromRequestParts, Extension},
response::Response,
};
use axum_macros::FromRequestParts;
#[derive(Clone, FromRequestParts)]
#[from_request(via(Extension))]
struct Extractor {
one: i32,
two: String,
three: bool,
}
fn assert_from_request()
where
Extractor: FromRequestParts<(), Rejection = Response>,
{
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>,
{
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor();
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>,
{
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum::{body::Body, routing::get, Extension, Router};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts, Clone)]
#[from_request(via(Extension))]
enum Extractor {}
async fn foo(_: Extractor) {}
fn main() {
Router::<(), Body>::new().route("/", get(foo));
}

View file

@ -0,0 +1,23 @@
use axum::{
extract::{rejection::TypedHeaderRejection, FromRequestParts, TypedHeader},
headers::{self, UserAgent},
response::Response,
};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {
uri: axum::http::Uri,
user_agent: TypedHeader<UserAgent>,
content_type: TypedHeader<headers::ContentType>,
etag: Option<TypedHeader<headers::ETag>>,
host: Result<TypedHeader<headers::Host>, TypedHeaderRejection>,
}
fn assert_from_request()
where
Extractor: FromRequestParts<(), Rejection = Response>,
{
}
fn main() {}

View file

@ -0,0 +1,34 @@
use axum::{
response::Response,
extract::{
rejection::TypedHeaderRejection,
Extension, FromRequestParts, TypedHeader,
},
headers::{self, UserAgent},
};
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor {
#[from_request(via(Extension))]
state: State,
#[from_request(via(TypedHeader))]
user_agent: UserAgent,
#[from_request(via(TypedHeader))]
content_type: headers::ContentType,
#[from_request(via(TypedHeader))]
etag: Option<headers::ETag>,
#[from_request(via(TypedHeader))]
host: Result<headers::Host, TypedHeaderRejection>,
}
fn assert_from_request()
where
Extractor: FromRequestParts<(), Rejection = Response>,
{
}
#[derive(Clone)]
struct State;
fn main() {}

View file

@ -0,0 +1,39 @@
use axum::{
extract::rejection::JsonRejection,
response::{IntoResponse, Response},
routing::get,
Router,
};
use axum_macros::FromRequest;
use std::collections::HashMap;
use serde::Deserialize;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
}
async fn handler(_: MyJson) {}
async fn handler_result(_: Result<MyJson, MyJsonRejection>) {}
#[derive(FromRequest, Deserialize)]
#[from_request(
via(axum::extract::Json),
rejection(MyJsonRejection),
)]
#[serde(transparent)]
struct MyJson(HashMap<String, String>);
struct MyJsonRejection {}
impl From<JsonRejection> for MyJsonRejection {
fn from(_: JsonRejection) -> Self {
todo!()
}
}
impl IntoResponse for MyJsonRejection {
fn into_response(self) -> Response {
todo!()
}
}

View file

@ -0,0 +1,39 @@
use axum::{
extract::rejection::QueryRejection,
response::{IntoResponse, Response},
routing::get,
Router,
};
use axum_macros::FromRequestParts;
use std::collections::HashMap;
use serde::Deserialize;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
}
async fn handler(_: MyQuery) {}
async fn handler_result(_: Result<MyQuery, MyQueryRejection>) {}
#[derive(FromRequestParts, Deserialize)]
#[from_request(
via(axum::extract::Query),
rejection(MyQueryRejection),
)]
#[serde(transparent)]
struct MyQuery(HashMap<String, String>);
struct MyQueryRejection {}
impl From<QueryRejection> for MyQueryRejection {
fn from(_: QueryRejection) -> Self {
todo!()
}
}
impl IntoResponse for MyQueryRejection {
fn into_response(self) -> Response {
todo!()
}
}

View file

@ -0,0 +1,61 @@
use axum::{
async_trait,
extract::{rejection::ExtensionRejection, FromRequestParts},
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
routing::get,
Extension, Router,
};
use axum_macros::FromRequestParts;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
}
async fn handler(_: MyExtractor) {}
async fn handler_result(_: Result<MyExtractor, MyRejection>) {}
#[derive(FromRequestParts)]
#[from_request(rejection(MyRejection))]
struct MyExtractor {
one: Extension<String>,
#[from_request(via(Extension))]
two: String,
three: OtherExtractor,
}
struct OtherExtractor;
#[async_trait]
impl<S> FromRequestParts<S> for OtherExtractor
where
S: Send + Sync,
{
// this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String);
async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
todo!()
}
}
struct MyRejection {}
impl From<ExtensionRejection> for MyRejection {
fn from(_: ExtensionRejection) -> Self {
todo!()
}
}
impl From<(StatusCode, String)> for MyRejection {
fn from(_: (StatusCode, String)) -> Self {
todo!()
}
}
impl IntoResponse for MyRejection {
fn into_response(self) -> Response {
todo!()
}
}

View file

@ -0,0 +1,33 @@
use axum::{
extract::rejection::ExtensionRejection,
response::{IntoResponse, Response},
routing::get,
Router,
};
use axum_macros::FromRequestParts;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
}
async fn handler(_: MyExtractor) {}
async fn handler_result(_: Result<MyExtractor, MyRejection>) {}
#[derive(FromRequestParts, Clone)]
#[from_request(via(axum::Extension), rejection(MyRejection))]
enum MyExtractor {}
struct MyRejection {}
impl From<ExtensionRejection> for MyRejection {
fn from(_: ExtensionRejection) -> Self {
todo!()
}
}
impl IntoResponse for MyRejection {
fn into_response(self) -> Response {
todo!()
}
}

View file

@ -0,0 +1,40 @@
use axum::{
extract::rejection::QueryRejection,
response::{IntoResponse, Response},
routing::get,
Router,
};
use axum_macros::FromRequestParts;
use serde::Deserialize;
fn main() {
let _: Router = Router::new().route("/", get(handler).post(handler_result));
}
#[derive(Deserialize)]
struct Payload {}
async fn handler(_: MyQuery<Payload>) {}
async fn handler_result(_: Result<MyQuery<Payload>, MyQueryRejection>) {}
#[derive(FromRequestParts)]
#[from_request(
via(axum::extract::Query),
rejection(MyQueryRejection),
)]
struct MyQuery<T>(T);
struct MyQueryRejection {}
impl From<QueryRejection> for MyQueryRejection {
fn from(_: QueryRejection) -> Self {
todo!()
}
}
impl IntoResponse for MyQueryRejection {
fn into_response(self) -> Response {
todo!()
}
}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor(axum::http::HeaderMap, axum::http::Method);
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<()>,
{
}
fn main() {}

View file

@ -0,0 +1,20 @@
use axum::extract::Query;
use axum_macros::FromRequestParts;
use serde::Deserialize;
#[derive(FromRequestParts)]
struct Extractor(
Query<Payload>,
axum::extract::Path<Payload>,
);
#[derive(Deserialize)]
struct Payload {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<()>,
{
}
fn main() {}

View file

@ -0,0 +1,21 @@
use axum::extract::Query;
use axum::response::Response;
use axum_macros::FromRequestParts;
use serde::Deserialize;
#[derive(FromRequestParts)]
struct Extractor(
#[from_request(via(Query))] Payload,
#[from_request(via(axum::extract::Path))] Payload,
);
#[derive(Deserialize)]
struct Payload {}
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<(), Rejection = Response>,
{
}
fn main() {}

View file

@ -0,0 +1,16 @@
use axum::Extension;
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor(#[from_request(via(Extension))] State);
#[derive(Clone)]
struct State;
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<()>,
{
}
fn main() {}

View file

@ -0,0 +1,12 @@
use axum_macros::FromRequestParts;
#[derive(FromRequestParts)]
struct Extractor;
fn assert_from_request()
where
Extractor: axum::extract::FromRequestParts<(), Rejection = std::convert::Infallible>,
{
}
fn main() {}