diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 7c261b50..8d69129a 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,3 +1,4 @@ +use crate::with_position::{Position, WithPosition}; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use std::collections::HashSet; @@ -173,12 +174,16 @@ fn check_inputs_impls_from_request( FnArg::Typed(typed) => is_self_pat_type(typed), }); - item_fn - .sig - .inputs - .iter() + WithPosition::new(item_fn.sig.inputs.iter()) .enumerate() .map(|(idx, arg)| { + let must_impl_from_request_parts = match &arg { + Position::First(_) | Position::Middle(_) => true, + Position::Last(_) | Position::Only(_) => false, + }; + + let arg = arg.into_inner(); + let (span, ty) = match arg { FnArg::Receiver(receiver) => { if receiver.reference.is_some() { @@ -228,11 +233,27 @@ fn check_inputs_impls_from_request( } }; + let check_fn_generics = if must_impl_from_request_parts { + quote! {} + } else { + quote! { } + }; + + let from_request_bound = if must_impl_from_request_parts { + quote! { + #ty: ::axum::extract::FromRequestParts<#state_ty> + Send + } + } else { + quote! { + #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send + } + }; + quote_spanned! {span=> #[allow(warnings)] - fn #check_fn() + fn #check_fn #check_fn_generics() where - #ty: ::axum::extract::FromRequest<#state_ty, #body_ty, M> + Send, + #from_request_bound, {} // we have to call the function to actually trigger a compile error @@ -472,15 +493,5 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option { #[test] fn ui() { - #[rustversion::stable] - fn go() { - let t = trybuild::TestCases::new(); - t.compile_fail("tests/debug_handler/fail/*.rs"); - t.pass("tests/debug_handler/pass/*.rs"); - } - - #[rustversion::not(stable)] - fn go() {} - - go(); + crate::run_ui_tests("debug_handler"); } diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 596d23d3..f92eaeb4 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -564,17 +564,7 @@ fn impl_enum_by_extracting_all_at_once( #[test] fn ui() { - #[rustversion::stable] - fn go() { - let t = trybuild::TestCases::new(); - t.compile_fail("tests/from_request/fail/*.rs"); - t.pass("tests/from_request/pass/*.rs"); - } - - #[rustversion::not(stable)] - fn go() {} - - go(); + crate::run_ui_tests("from_request"); } /// For some reason the compiler error for this is different locally and on CI. No idea why... So diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 47442724..a3ebf2a5 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -50,6 +50,7 @@ use syn::parse::Parse; mod debug_handler; mod from_request; mod typed_path; +mod with_position; /// Derive an implementation of [`FromRequest`]. /// @@ -561,3 +562,37 @@ where Err(err) => err.into_compile_error().into(), } } + +#[cfg(test)] +fn run_ui_tests(directory: &str) { + #[rustversion::stable] + fn go(directory: &str) { + let t = trybuild::TestCases::new(); + + if let Ok(mut path) = std::env::var("AXUM_TEST_ONLY") { + if let Some(path_without_prefix) = path.strip_prefix("axum-macros/") { + path = path_without_prefix.to_owned(); + } + + if !path.contains(&format!("/{}/", directory)) { + return; + } + + if path.contains("/fail/") { + t.compile_fail(path); + } else if path.contains("/pass/") { + t.pass(path); + } else { + panic!() + } + } else { + t.compile_fail(format!("tests/{}/fail/*.rs", directory)); + t.pass(format!("tests/{}/pass/*.rs", directory)); + } + } + + #[rustversion::not(stable)] + fn go(directory: &str) {} + + go(directory); +} diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index efbf733c..ede7a581 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -423,15 +423,5 @@ fn map_err_rejection(rejection: &Option) -> TokenStream { #[test] fn ui() { - #[rustversion::stable] - fn go() { - let t = trybuild::TestCases::new(); - t.compile_fail("tests/typed_path/fail/*.rs"); - t.pass("tests/typed_path/pass/*.rs"); - } - - #[rustversion::not(stable)] - fn go() {} - - go(); + crate::run_ui_tests("typed_path"); } diff --git a/axum-macros/src/with_position.rs b/axum-macros/src/with_position.rs new file mode 100644 index 00000000..2e0caa50 --- /dev/null +++ b/axum-macros/src/with_position.rs @@ -0,0 +1,116 @@ +// this is copied from itertools under the following license +// +// Copyright (c) 2015 +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::iter::{Fuse, FusedIterator, Peekable}; + +pub(crate) struct WithPosition +where + I: Iterator, +{ + handled_first: bool, + peekable: Peekable>, +} + +impl WithPosition +where + I: Iterator, +{ + pub(crate) fn new(iter: I) -> WithPosition { + WithPosition { + handled_first: false, + peekable: iter.fuse().peekable(), + } + } +} + +impl Clone for WithPosition +where + I: Clone + Iterator, + I::Item: Clone, +{ + fn clone(&self) -> Self { + Self { + handled_first: self.handled_first, + peekable: self.peekable.clone(), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Position { + First(T), + Middle(T), + Last(T), + Only(T), +} + +impl Position { + pub(crate) fn into_inner(self) -> T { + match self { + Position::First(x) | Position::Middle(x) | Position::Last(x) | Position::Only(x) => x, + } + } +} + +impl Iterator for WithPosition { + type Item = Position; + + fn next(&mut self) -> Option { + match self.peekable.next() { + Some(item) => { + if !self.handled_first { + // Haven't seen the first item yet, and there is one to give. + self.handled_first = true; + // Peek to see if this is also the last item, + // in which case tag it as `Only`. + match self.peekable.peek() { + Some(_) => Some(Position::First(item)), + None => Some(Position::Only(item)), + } + } else { + // Have seen the first item, and there's something left. + // Peek to see if this is the last item. + match self.peekable.peek() { + Some(_) => Some(Position::Middle(item)), + None => Some(Position::Last(item)), + } + } + } + // Iterator is finished. + None => None, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.peekable.size_hint() + } +} + +impl ExactSizeIterator for WithPosition where I: ExactSizeIterator {} + +impl FusedIterator for WithPosition {} diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 8d46455c..c669438a 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -16,7 +16,10 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied and 26 others = note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` - --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 + --> tests/debug_handler/fail/argument_not_extractor.rs:3:1 | +3 | #[debug_handler] + | ^^^^^^^^^^^^^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` 4 | async fn handler(foo: bool) {} - | ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` + | ---- required by a bound in this + = note: this error originates in the attribute macro `debug_handler` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.rs b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.rs new file mode 100644 index 00000000..e4700c2a --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.rs @@ -0,0 +1,7 @@ +use axum_macros::debug_handler; +use axum::http::Method; + +#[debug_handler] +async fn handler(_: String, _: Method) {} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr new file mode 100644 index 00000000..87f3ab36 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr @@ -0,0 +1,18 @@ +error[E0277]: the trait bound `String: FromRequestParts<()>` is not satisfied + --> tests/debug_handler/fail/doesnt_implement_from_request_parts.rs:4:1 + | +4 | #[debug_handler] + | ^^^^^^^^^^^^^^^^ the trait `FromRequestParts<()>` is not implemented for `String` + | + = help: the following other types implement trait `FromRequestParts`: + <() as FromRequestParts> + <(T1, T2) as FromRequestParts> + <(T1, T2, T3) as FromRequestParts> + <(T1, T2, T3, T4) as FromRequestParts> + <(T1, T2, T3, T4, T5) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> + and 26 others + = help: see issue #48214 + = note: this error originates in the attribute macro `debug_handler` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index c4be8b52..782fc930 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -1,4 +1,4 @@ -use axum::{async_trait, extract::FromRequest, http::Request, response::IntoResponse}; +use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::IntoResponse}; use axum_macros::debug_handler; fn main() {} @@ -116,14 +116,13 @@ impl A { } #[async_trait] -impl FromRequest for A +impl FromRequestParts for A where - B: Send + 'static, S: Send + Sync, { type Rejection = (); - async fn from_request(_req: Request, _state: &S) -> Result { + async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result { unimplemented!() } }