Improve debug_handler on tuple response types (#2201)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
Sleep_AllDay 2023-12-31 05:48:35 +08:00 committed by GitHub
parent 85573e0573
commit d2cea5cdbd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 357 additions and 17 deletions

View file

@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
[#2201]: https://github.com/tokio-rs/axum/pull/2201
# 0.4.0 (27. November, 2023)

View file

@ -4,9 +4,9 @@ use crate::{
attr_parsing::{parse_assignment_attribute, second},
with_position::{Position, WithPosition},
};
use proc_macro2::{Span, TokenStream};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type};
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
let Attrs { state_ty } = attr;
@ -15,7 +15,12 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
let check_extractor_count = check_extractor_count(&item_fn);
let check_path_extractor = check_path_extractor(&item_fn);
let check_output_impls_into_response = check_output_impls_into_response(&item_fn);
let check_output_tuples = check_output_tuples(&item_fn);
let check_output_impls_into_response = if check_output_tuples.is_empty() {
check_output_impls_into_response(&item_fn)
} else {
check_output_tuples
};
// If the function is generic, we can't reliably check its inputs or whether the future it
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
@ -180,7 +185,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStr
FnArg::Typed(typed) => is_self_pat_type(typed),
});
WithPosition::new(item_fn.sig.inputs.iter())
WithPosition::new(&item_fn.sig.inputs)
.enumerate()
.map(|(idx, arg)| {
let must_impl_from_request_parts = match &arg {
@ -275,8 +280,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStr
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #call_check_fn()
{
fn #call_check_fn() {
#call_check_fn_body
}
}
@ -284,6 +288,163 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStr
.collect::<TokenStream>()
}
fn check_output_tuples(item_fn: &ItemFn) -> TokenStream {
let elems = match &item_fn.sig.output {
ReturnType::Type(_, ty) => match &**ty {
Type::Tuple(tuple) => &tuple.elems,
_ => return quote! {},
},
ReturnType::Default => return quote! {},
};
let handler_ident = &item_fn.sig.ident;
match elems.len() {
0 => quote! {},
n if n > 17 => syn::Error::new_spanned(
&item_fn.sig.output,
"Cannot return tuples with more than 17 elements",
)
.to_compile_error(),
_ => WithPosition::new(elems)
.enumerate()
.map(|(idx, arg)| match arg {
Position::First(ty) => match extract_clean_typename(ty).as_deref() {
Some("StatusCode" | "Response") => quote! {},
Some("Parts") => check_is_response_parts(ty, handler_ident, idx),
Some(_) | None => {
if let Some(tn) = well_known_last_response_type(ty) {
syn::Error::new_spanned(
ty,
format!(
"`{tn}` must be the last element \
in a response tuple"
),
)
.to_compile_error()
} else {
check_into_response_parts(ty, handler_ident, idx)
}
}
},
Position::Middle(ty) => {
if let Some(tn) = well_known_last_response_type(ty) {
syn::Error::new_spanned(
ty,
format!("`{tn}` must be the last element in a response tuple"),
)
.to_compile_error()
} else {
check_into_response_parts(ty, handler_ident, idx)
}
}
Position::Last(ty) | Position::Only(ty) => check_into_response(handler_ident, ty),
})
.collect::<TokenStream>(),
}
}
fn check_into_response(handler: &Ident, ty: &Type) -> TokenStream {
let (span, ty) = (ty.span(), ty.clone());
let check_fn = format_ident!(
"__axum_macros_check_{handler}_into_response_check",
span = span,
);
let call_check_fn = format_ident!(
"__axum_macros_check_{handler}_into_response_call_check",
span = span,
);
let call_check_fn_body = quote_spanned! {span=>
#check_fn();
};
let from_request_bound = quote_spanned! {span=>
#ty: ::axum::response::IntoResponse
};
quote_spanned! {span=>
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #check_fn()
where
#from_request_bound,
{}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #call_check_fn() {
#call_check_fn_body
}
}
}
fn check_is_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream {
let (span, ty) = (ty.span(), ty.clone());
let check_fn = format_ident!(
"__axum_macros_check_{}_is_response_parts_{index}_check",
ident,
span = span,
);
quote_spanned! {span=>
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #check_fn(parts: #ty) -> ::axum::http::response::Parts {
parts
}
}
}
fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream {
let (span, ty) = (ty.span(), ty.clone());
let check_fn = format_ident!(
"__axum_macros_check_{}_into_response_parts_{index}_check",
ident,
span = span,
);
let call_check_fn = format_ident!(
"__axum_macros_check_{}_into_response_parts_{index}_call_check",
ident,
span = span,
);
let call_check_fn_body = quote_spanned! {span=>
#check_fn();
};
let from_request_bound = quote_spanned! {span=>
#ty: ::axum::response::IntoResponseParts
};
quote_spanned! {span=>
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #check_fn()
where
#from_request_bound,
{}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #call_check_fn() {
#call_check_fn_body
}
}
}
fn check_input_order(item_fn: &ItemFn) -> Option<TokenStream> {
let types_that_consume_the_request = item_fn
.sig
@ -334,7 +495,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option<TokenStream> {
compile_error!(#error);
})
} else {
let types = WithPosition::new(types_that_consume_the_request.into_iter())
let types = WithPosition::new(types_that_consume_the_request)
.map(|pos| match pos {
Position::First((_, type_name, _)) | Position::Middle((_, type_name, _)) => {
format!("`{type_name}`, ")
@ -355,18 +516,18 @@ fn check_input_order(item_fn: &ItemFn) -> Option<TokenStream> {
}
}
fn request_consuming_type_name(ty: &Type) -> Option<&'static str> {
fn extract_clean_typename(ty: &Type) -> Option<String> {
let path = match ty {
Type::Path(type_path) => &type_path.path,
_ => return None,
};
path.segments.last().map(|p| p.ident.to_string())
}
let ident = match path.segments.last() {
Some(path_segment) => &path_segment.ident,
None => return None,
};
fn request_consuming_type_name(ty: &Type) -> Option<&'static str> {
let typename = extract_clean_typename(ty)?;
let type_name = match &*ident.to_string() {
let type_name = match &*typename {
"Json" => "Json<_>",
"RawBody" => "RawBody<_>",
"RawForm" => "RawForm",
@ -384,6 +545,25 @@ fn request_consuming_type_name(ty: &Type) -> Option<&'static str> {
Some(type_name)
}
fn well_known_last_response_type(ty: &Type) -> Option<&'static str> {
let typename = match extract_clean_typename(ty) {
Some(tn) => tn,
None => return None,
};
let type_name = match &*typename {
"Json" => "Json<_>",
"Protobuf" => "Protobuf",
"JsonLines" => "JsonLines<_>",
"Form" => "Form<_>",
"Bytes" => "Bytes",
"String" => "String",
_ => return None,
};
Some(type_name)
}
fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream {
let ty = match &item_fn.sig.output {
syn::ReturnType::Default => return quote! {},

View file

@ -40,10 +40,10 @@ impl<I> WithPosition<I>
where
I: Iterator,
{
pub(crate) fn new(iter: I) -> WithPosition<I> {
pub(crate) fn new(iter: impl IntoIterator<IntoIter = I>) -> WithPosition<I> {
WithPosition {
handled_first: false,
peekable: iter.fuse().peekable(),
peekable: iter.into_iter().fuse().peekable(),
}
}
}

View file

@ -0,0 +1,29 @@
use axum::response::AppendHeaders;
#[axum::debug_handler]
async fn handler(
) -> (
axum::http::StatusCode,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
axum::http::StatusCode,
) {
panic!()
}
fn main() {}

View file

@ -0,0 +1,12 @@
error: Cannot return tuples with more than 17 elements
--> tests/debug_handler/fail/output_tuple_too_many.rs:5:3
|
5 | ) -> (
| ___^
6 | | axum::http::StatusCode,
7 | | AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
8 | | AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>,
... |
24 | | axum::http::StatusCode,
25 | | ) {
| |_^

View file

@ -0,0 +1,10 @@
#[axum::debug_handler]
async fn handler(
) -> (
axum::http::request::Parts, // this should be response parts, not request parts
axum::http::StatusCode,
) {
panic!()
}
fn main(){}

View file

@ -0,0 +1,8 @@
error[E0308]: mismatched types
--> tests/debug_handler/fail/returning_request_parts.rs:4:5
|
4 | axum::http::request::Parts, // this should be response parts, not request parts
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| |
| expected `axum::http::response::Parts`, found `axum::http::request::Parts`
| expected `axum::http::response::Parts` because of return type

View file

@ -0,0 +1,10 @@
#![allow(unused_parens)]
struct NotIntoResponse;
#[axum::debug_handler]
async fn handler() -> (NotIntoResponse) {
panic!()
}
fn main() {}

View file

@ -0,0 +1,21 @@
error[E0277]: the trait bound `NotIntoResponse: IntoResponse` is not satisfied
--> tests/debug_handler/fail/single_wrong_return_tuple.rs:6:23
|
6 | async fn handler() -> (NotIntoResponse) {
| ^^^^^^^^^^^^^^^^^ the trait `IntoResponse` is not implemented for `NotIntoResponse`
|
= help: the following other types implement trait `IntoResponse`:
Box<str>
Box<[u8]>
axum::body::Bytes
Body
axum::extract::rejection::FailedToBufferBody
axum::extract::rejection::LengthLimitError
axum::extract::rejection::UnknownBodyError
bytes::bytes_mut::BytesMut
and $N others
note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check`
--> tests/debug_handler/fail/single_wrong_return_tuple.rs:6:23
|
6 | async fn handler() -> (NotIntoResponse) {
| ^^^^^^^^^^^^^^^^^ required by this bound in `check`

View file

@ -0,0 +1,30 @@
#![allow(unused_parens)]
#[axum::debug_handler]
async fn named_type() -> (
axum::http::StatusCode,
axum::Json<&'static str>,
axum::response::AppendHeaders<[( axum::http::HeaderName,&'static str); 1]>,
) {
panic!()
}
struct CustomIntoResponse{
}
impl axum::response::IntoResponse for CustomIntoResponse{
fn into_response(self) -> axum::response::Response {
todo!()
}
}
#[axum::debug_handler]
async fn custom_type() -> (
axum::http::StatusCode,
CustomIntoResponse,
axum::response::AppendHeaders<[( axum::http::HeaderName,&'static str); 1]>,
) {
panic!()
}
fn main() {}

View file

@ -0,0 +1,24 @@
error: `Json<_>` must be the last element in a response tuple
--> tests/debug_handler/fail/wrong_return_tuple.rs:6:5
|
6 | axum::Json<&'static str>,
| ^^^^^^^^^^^^^^^^^^^^^^^^
error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not satisfied
--> tests/debug_handler/fail/wrong_return_tuple.rs:24:5
|
24 | CustomIntoResponse,
| ^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `CustomIntoResponse`
|
= help: the following other types implement trait `IntoResponseParts`:
AppendHeaders<I>
HeaderMap
Extension<T>
Extensions
Option<T>
[(K, V); N]
()
(T1,)
and $N others
= help: see issue #48214
= help: add `#![feature(trivial_bounds)]` to the crate attributes to enable

View file

@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **fixed:** Improve `debug_handler` on tuple response types ([#2201])
[#2201]: https://github.com/tokio-rs/axum/pull/2201
# 0.7.3 (29. December, 2023)

View file

@ -31,6 +31,18 @@ async fn nest() {
assert_eq!(res.text().await, "fallback");
}
#[crate::test]
async fn two() {
let app = Router::new()
.route("/first", get(|| async {}))
.route("/second", get(|| async {}))
.fallback(get(|| async { "fallback" }));
let client = TestClient::new(app);
let res = client.get("/does-not-exist").await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "fallback");
}
#[crate::test]
async fn or() {
let one = Router::new().route("/one", get(|| async {}));