Merge branch 'tokio-rs:main' into main

Justus Fluegel 2024-06-22 01:15:47 +02:00 committed by GitHub
commit b589cb3ab4
40 changed files with 598 additions and 236 deletions


@ -56,13 +56,19 @@ jobs:
crate: [axum, axum-core, axum-extra, axum-macros]
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@nightly
# Pinned version due to failing `cargo-public-api-crates`.
- uses: dtolnay/rust-toolchain@master
with:
toolchain: nightly-2024-06-06
- uses: Swatinem/rust-cache@v2
- name: Install cargo-public-api-crates
run: |
cargo install --git https://github.com/davidpdrsn/cargo-public-api-crates
- name: Build rustdoc
run: |
cargo rustdoc --all-features --manifest-path ${{ matrix.crate }}/Cargo.toml -- -Z unstable-options --output-format json
- name: cargo public-api-crates check
run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml check
run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml --skip-build check
test-versions:
needs: check


@ -26,6 +26,7 @@ http-body = "1.0.0"
http-body-util = "0.1.0"
mime = "0.3.16"
pin-project-lite = "0.2.7"
rustversion = "1.0.9"
sync_wrapper = "1.0.0"
tower-layer = "0.3"
tower-service = "0.3"
@ -34,9 +35,6 @@ tower-service = "0.3"
tower-http = { version = "0.5.0", optional = true, features = ["limit"] }
tracing = { version = "0.1.37", default-features = false, optional = true }
[build-dependencies]
rustversion = "1.0.9"
[dev-dependencies]
axum = { path = "../axum", version = "0.7.2" }
axum-extra = { path = "../axum-extra", features = ["typed-header"] }


@ -1,7 +0,0 @@
#[rustversion::nightly]
fn main() {
println!("cargo:rustc-cfg=nightly_error_messages");
}
#[rustversion::not(nightly)]
fn main() {}


@ -43,8 +43,8 @@ mod private {
///
/// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html
#[async_trait]
#[cfg_attr(
nightly_error_messages,
#[rustversion::attr(
since(1.78),
diagnostic::on_unimplemented(
note = "Function argument is not a valid axum extractor. \nSee `https://docs.rs/axum/0.7/axum/extract/index.html` for details",
)
@ -70,8 +70,8 @@ pub trait FromRequestParts<S>: Sized {
///
/// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html
#[async_trait]
#[cfg_attr(
nightly_error_messages,
#[rustversion::attr(
since(1.78),
diagnostic::on_unimplemented(
note = "Function argument is not a valid axum extractor. \nSee `https://docs.rs/axum/0.7/axum/extract/index.html` for details",
)
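The hunk above replaces the nightly-only `cfg_attr(nightly_error_messages, ...)` gate with `rustversion`, which emits `#[diagnostic::on_unimplemented]` only on compilers where that attribute is stable (Rust 1.78+). A minimal sketch of the same pattern, using a hypothetical `Extractor` trait and assuming the `rustversion` crate as a dependency:

// Hypothetical trait, not part of axum; shows only the gating pattern.
#[rustversion::attr(
    since(1.78),
    diagnostic::on_unimplemented(
        note = "`{Self}` is not a valid extractor; see the axum extract docs"
    )
)]
trait Extractor {}

fn require_extractor<T: Extractor>(_value: T) {}

fn main() {
    // On Rust 1.78+ a call like `require_extractor(5_u8)` fails to compile
    // with the custom note above; on older toolchains the attribute is
    // simply not emitted, so the code still builds.
}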


@ -1,4 +1,5 @@
/// Private API.
#[cfg(feature = "tracing")]
#[doc(hidden)]
#[macro_export]
macro_rules! __log_rejection {
@ -7,7 +8,6 @@ macro_rules! __log_rejection {
body_text = $body_text:expr,
status = $status:expr,
) => {
#[cfg(feature = "tracing")]
{
tracing::event!(
target: "axum::rejection",
@ -21,6 +21,17 @@ macro_rules! __log_rejection {
};
}
#[cfg(not(feature = "tracing"))]
#[doc(hidden)]
#[macro_export]
macro_rules! __log_rejection {
(
rejection_type = $ty:ident,
body_text = $body_text:expr,
status = $status:expr,
) => {};
}
/// Private API.
#[doc(hidden)]
#[macro_export]


@ -29,7 +29,7 @@ use std::fmt;
/// )
/// }
/// ```
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct AppendHeaders<I>(pub I);


@ -15,6 +15,7 @@ version = "0.9.3"
default = ["tracing"]
async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"]
attachment = ["dep:tracing"]
cookie = ["dep:cookie"]
cookie-private = ["cookie", "cookie?/private"]
cookie-signed = ["cookie", "cookie?/signed"]
@ -33,7 +34,7 @@ json-lines = [
multipart = ["dep:multer"]
protobuf = ["dep:prost"]
query = ["dep:serde_html_form"]
tracing = ["dep:tracing", "axum-core/tracing"]
tracing = ["dep:tracing", "axum-core/tracing", "axum/tracing"]
typed-header = ["dep:headers"]
typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"]


@ -23,8 +23,7 @@ use std::marker::PhantomData;
/// Additionally, a `JsonRejection` error will be returned when calling `deserialize` if:
///
/// - The body doesn't contain syntactically valid JSON.
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
/// type.
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type.
/// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`).
///
/// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the


@ -0,0 +1,103 @@
use axum::response::IntoResponse;
use http::{header, HeaderMap, HeaderValue};
use tracing::error;
/// A file attachment response.
///
/// This type will set the `Content-Disposition` header to `attachment`. In response, a web
/// browser will offer to download the file instead of displaying it directly.
///
/// Use the `filename` and `content_type` methods to set the filename or content-type of the
/// attachment. If these values are not set, they will not be sent.
///
/// # Example
///
/// ```rust
/// use axum::{http::StatusCode, routing::get, Router};
/// use axum_extra::response::Attachment;
///
/// async fn cargo_toml() -> Result<Attachment<String>, (StatusCode, String)> {
/// let file_contents = tokio::fs::read_to_string("Cargo.toml")
/// .await
/// .map_err(|err| (StatusCode::NOT_FOUND, format!("File not found: {err}")))?;
/// Ok(Attachment::new(file_contents)
/// .filename("Cargo.toml")
/// .content_type("text/x-toml"))
/// }
///
/// let app = Router::new().route("/Cargo.toml", get(cargo_toml));
/// let _: Router = app;
/// ```
///
/// # Note
///
/// If you use axum with hyper, hyper will set the `Content-Length` if it is known.
///
#[derive(Debug)]
pub struct Attachment<T> {
inner: T,
filename: Option<HeaderValue>,
content_type: Option<HeaderValue>,
}
impl<T: IntoResponse> Attachment<T> {
/// Creates a new [`Attachment`].
pub fn new(inner: T) -> Self {
Self {
inner,
filename: None,
content_type: None,
}
}
/// Sets the filename of the [`Attachment`].
///
/// This updates the `Content-Disposition` header to add a filename.
pub fn filename<H: TryInto<HeaderValue>>(mut self, value: H) -> Self {
self.filename = if let Ok(filename) = value.try_into() {
Some(filename)
} else {
error!("Attachment filename contains invalid characters");
None
};
self
}
/// Sets the content-type of the [`Attachment`]
pub fn content_type<H: TryInto<HeaderValue>>(mut self, value: H) -> Self {
if let Ok(content_type) = value.try_into() {
self.content_type = Some(content_type);
} else {
error!("Attachment content-type contains invalid characters");
}
self
}
}
impl<T> IntoResponse for Attachment<T>
where
T: IntoResponse,
{
fn into_response(self) -> axum::response::Response {
let mut headers = HeaderMap::new();
if let Some(content_type) = self.content_type {
headers.append(header::CONTENT_TYPE, content_type);
}
let content_disposition = if let Some(filename) = self.filename {
let mut bytes = b"attachment; filename=\"".to_vec();
bytes.extend_from_slice(filename.as_bytes());
bytes.push(b'\"');
HeaderValue::from_bytes(&bytes).expect("This was a HeaderValue so this cannot fail")
} else {
HeaderValue::from_static("attachment")
};
headers.append(header::CONTENT_DISPOSITION, content_disposition);
(headers, self.inner).into_response()
}
}
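A short usage sketch of the type defined above (assuming `axum-extra` is built with the `attachment` feature), showing the headers produced by `into_response`:

use axum::response::IntoResponse;
use axum_extra::response::Attachment;

fn main() {
    let response = Attachment::new("hello world")
        .filename("hello.txt")
        .content_type("text/plain")
        .into_response();

    // `Content-Disposition` is built by the `IntoResponse` impl above;
    // `Content-Type` comes from the builder method.
    let headers = response.headers();
    assert_eq!(
        headers["content-disposition"],
        "attachment; filename=\"hello.txt\""
    );
    assert_eq!(headers["content-type"], "text/plain");
}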


@ -3,6 +3,9 @@
#[cfg(feature = "erased-json")]
mod erased_json;
#[cfg(feature = "attachment")]
mod attachment;
#[cfg(feature = "erased-json")]
pub use erased_json::ErasedJson;
@ -10,6 +13,9 @@ pub use erased_json::ErasedJson;
#[doc(no_inline)]
pub use crate::json_lines::JsonLines;
#[cfg(feature = "attachment")]
pub use attachment::Attachment;
macro_rules! mime_response {
(
$(#[$m:meta])*


@ -85,12 +85,12 @@ use serde::Serialize;
///
/// - A `TypedPath` implementation.
/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`],
/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must
/// also implement [`serde::Deserialize`], unless it's a unit struct.
/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must
/// also implement [`serde::Deserialize`], unless it's a unit struct.
/// - A [`Display`] implementation that interpolates the captures. This can be used to, among other
/// things, create links to known paths and have them verified statically. Note that the
/// [`Display`] implementation for each field must return something that's compatible with its
/// [`Deserialize`] implementation.
/// things, create links to known paths and have them verified statically. Note that the
/// [`Display`] implementation for each field must return something that's compatible with its
/// [`Deserialize`] implementation.
///
/// Additionally, the macro will verify that the captures in the path match the fields of the struct.
/// For example this fails to compile since the struct doesn't have a `team_id` field:


@ -6,7 +6,7 @@ use axum::{
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use headers::{Header, HeaderMapExt};
use http::request::Parts;
use http::{request::Parts, StatusCode};
use std::convert::Infallible;
/// Extractor and response that works with typed header values from [`headers`].
@ -156,7 +156,10 @@ impl TypedHeaderRejectionReason {
impl IntoResponse for TypedHeaderRejection {
fn into_response(self) -> Response {
(http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
let status = StatusCode::BAD_REQUEST;
let body = self.to_string();
axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,);
(status, body).into_response()
}
}


@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **added:** Add `#[debug_middleware]` ([#1993], [#2725])
[#1993]: https://github.com/tokio-rs/axum/pull/1993
[#2725]: https://github.com/tokio-rs/axum/pull/2725
# 0.4.1 (13. January, 2024)


@ -19,7 +19,6 @@ __private = ["syn/visit-mut"]
proc-macro = true
[dependencies]
heck = "0.4"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = [


@ -1,4 +1,4 @@
use std::collections::HashSet;
use std::{collections::HashSet, fmt};
use crate::{
attr_parsing::{parse_assignment_attribute, second},
@ -8,13 +8,13 @@ use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type};
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenStream {
let Attrs { state_ty } = attr;
let mut state_ty = state_ty.map(second);
let check_extractor_count = check_extractor_count(&item_fn);
let check_path_extractor = check_path_extractor(&item_fn);
let check_extractor_count = check_extractor_count(&item_fn, kind);
let check_path_extractor = check_path_extractor(&item_fn, kind);
let check_output_tuples = check_output_tuples(&item_fn);
let check_output_impls_into_response = if check_output_tuples.is_empty() {
check_output_impls_into_response(&item_fn)
@ -37,8 +37,10 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
err = Some(
syn::Error::new(
Span::call_site(),
"can't infer state type, please add set it explicitly, as in \
`#[debug_handler(state = MyStateType)]`",
format!(
"can't infer state type, please add set it explicitly, as in \
`#[axum_macros::debug_{kind}(state = MyStateType)]`"
),
)
.into_compile_error(),
);
@ -48,16 +50,16 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
err.unwrap_or_else(|| {
let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(()));
let check_future_send = check_future_send(&item_fn);
let check_future_send = check_future_send(&item_fn, kind);
if let Some(check_input_order) = check_input_order(&item_fn) {
if let Some(check_input_order) = check_input_order(&item_fn, kind) {
quote! {
#check_input_order
#check_future_send
}
} else {
let check_inputs_impls_from_request =
check_inputs_impls_from_request(&item_fn, state_ty);
check_inputs_impls_from_request(&item_fn, state_ty, kind);
quote! {
#check_inputs_impls_from_request
@ -68,17 +70,45 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
} else {
syn::Error::new_spanned(
&item_fn.sig.generics,
"`#[axum_macros::debug_handler]` doesn't support generic functions",
format!("`#[axum_macros::debug_{kind}]` doesn't support generic functions"),
)
.into_compile_error()
};
let middleware_takes_next_as_last_arg =
matches!(kind, FunctionKind::Middleware).then(|| next_is_last_input(&item_fn));
quote! {
#item_fn
#check_extractor_count
#check_path_extractor
#check_output_impls_into_response
#check_inputs_and_future_send
#middleware_takes_next_as_last_arg
}
}
#[derive(Clone, Copy)]
pub(crate) enum FunctionKind {
Handler,
Middleware,
}
impl fmt::Display for FunctionKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FunctionKind::Handler => f.write_str("handler"),
FunctionKind::Middleware => f.write_str("middleware"),
}
}
}
impl FunctionKind {
fn name_uppercase_plural(&self) -> &'static str {
match self {
FunctionKind::Handler => "Handlers",
FunctionKind::Middleware => "Middleware",
}
}
}
@ -110,25 +140,36 @@ impl Parse for Attrs {
}
}
fn check_extractor_count(item_fn: &ItemFn) -> Option<TokenStream> {
fn check_extractor_count(item_fn: &ItemFn, kind: FunctionKind) -> Option<TokenStream> {
let max_extractors = 16;
if item_fn.sig.inputs.len() <= max_extractors {
let inputs = item_fn
.sig
.inputs
.iter()
.filter(|arg| skip_next_arg(arg, kind))
.count();
if inputs <= max_extractors {
None
} else {
let error_message = format!(
"Handlers cannot take more than {max_extractors} arguments. \
"{} cannot take more than {max_extractors} arguments. \
Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors",
kind.name_uppercase_plural(),
);
let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error();
Some(error)
}
}
fn extractor_idents(item_fn: &ItemFn) -> impl Iterator<Item = (usize, &syn::FnArg, &syn::Ident)> {
fn extractor_idents(
item_fn: &ItemFn,
kind: FunctionKind,
) -> impl Iterator<Item = (usize, &syn::FnArg, &syn::Ident)> {
item_fn
.sig
.inputs
.iter()
.filter(move |arg| skip_next_arg(arg, kind))
.enumerate()
.filter_map(|(idx, fn_arg)| match fn_arg {
FnArg::Receiver(_) => None,
@ -146,8 +187,8 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator<Item = (usize, &syn::FnAr
})
}
fn check_path_extractor(item_fn: &ItemFn) -> TokenStream {
let path_extractors = extractor_idents(item_fn)
fn check_path_extractor(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream {
let path_extractors = extractor_idents(item_fn, kind)
.filter(|(_, _, ident)| *ident == "Path")
.collect::<Vec<_>>();
@ -179,113 +220,122 @@ fn is_self_pat_type(typed: &syn::PatType) -> bool {
ident == "self"
}
fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStream {
fn check_inputs_impls_from_request(
item_fn: &ItemFn,
state_ty: Type,
kind: FunctionKind,
) -> TokenStream {
let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg {
FnArg::Receiver(_) => true,
FnArg::Typed(typed) => is_self_pat_type(typed),
});
WithPosition::new(&item_fn.sig.inputs)
.enumerate()
.map(|(idx, arg)| {
let must_impl_from_request_parts = match &arg {
Position::First(_) | Position::Middle(_) => true,
Position::Last(_) | Position::Only(_) => false,
};
WithPosition::new(
item_fn
.sig
.inputs
.iter()
.filter(|arg| skip_next_arg(arg, kind)),
)
.enumerate()
.map(|(idx, arg)| {
let must_impl_from_request_parts = match &arg {
Position::First(_) | Position::Middle(_) => true,
Position::Last(_) | Position::Only(_) => false,
};
let arg = arg.into_inner();
let arg = arg.into_inner();
let (span, ty) = match arg {
FnArg::Receiver(receiver) => {
if receiver.reference.is_some() {
return syn::Error::new_spanned(
receiver,
"Handlers must only take owned values",
)
.into_compile_error();
}
let (span, ty) = match arg {
FnArg::Receiver(receiver) => {
if receiver.reference.is_some() {
return syn::Error::new_spanned(
receiver,
"Handlers must only take owned values",
)
.into_compile_error();
}
let span = receiver.span();
let span = receiver.span();
(span, syn::parse_quote!(Self))
}
FnArg::Typed(typed) => {
let ty = &typed.ty;
let span = ty.span();
if is_self_pat_type(typed) {
(span, syn::parse_quote!(Self))
}
FnArg::Typed(typed) => {
let ty = &typed.ty;
let span = ty.span();
if is_self_pat_type(typed) {
(span, syn::parse_quote!(Self))
} else {
(span, ty.clone())
}
}
};
let consumes_request = request_consuming_type_name(&ty).is_some();
let check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_call_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn_body = if takes_self {
quote_spanned! {span=>
Self::#check_fn();
}
} else {
quote_spanned! {span=>
#check_fn();
}
};
let check_fn_generics = if must_impl_from_request_parts || consumes_request {
quote! {}
} else {
quote! { <M> }
};
let from_request_bound = if must_impl_from_request_parts {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequestParts<#state_ty> + Send
}
} else if consumes_request {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequest<#state_ty> + Send
}
} else {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequest<#state_ty, M> + Send
}
};
quote_spanned! {span=>
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #check_fn #check_fn_generics()
where
#from_request_bound,
{}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
#[allow(unreachable_code)]
#[doc(hidden)]
fn #call_check_fn() {
#call_check_fn_body
} else {
(span, ty.clone())
}
}
})
.collect::<TokenStream>()
};
let consumes_request = request_consuming_type_name(&ty).is_some();
let check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn = format_ident!(
"__axum_macros_check_{}_{}_from_request_call_check",
item_fn.sig.ident,
idx,
span = span,
);
let call_check_fn_body = if takes_self {
quote_spanned! {span=>
Self::#check_fn();
}
} else {
quote_spanned! {span=>
#check_fn();
}
};
let check_fn_generics = if must_impl_from_request_parts || consumes_request {
quote! {}
} else {
quote! { <M> }
};
let from_request_bound = if must_impl_from_request_parts {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequestParts<#state_ty> + Send
}
} else if consumes_request {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequest<#state_ty> + Send
}
} else {
quote_spanned! {span=>
#ty: ::axum::extract::FromRequest<#state_ty, M> + Send
}
};
quote_spanned! {span=>
#[allow(warnings)]
#[doc(hidden)]
fn #check_fn #check_fn_generics()
where
#from_request_bound,
{}
// we have to call the function to actually trigger a compile error
// since the function is generic, just defining it is not enough
#[allow(warnings)]
#[doc(hidden)]
fn #call_check_fn()
{
#call_check_fn_body
}
}
})
.collect::<TokenStream>()
}
fn check_output_tuples(item_fn: &ItemFn) -> TokenStream {
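The comment in the hunk above about having to *call* the generated check function reflects a general property of Rust generics: a bound on a generic function is only enforced at the call site, not at the definition. A minimal standalone sketch (hypothetical names, not the macro's real output):

// Defining this is always fine: the bound only constrains callers.
fn assert_impls_clone<T: Clone>() {}

#[allow(dead_code)]
struct NotClone;

fn main() {
    assert_impls_clone::<String>(); // compiles: `String: Clone`
    // Uncommenting the next line is what actually triggers the error
    // "`NotClone: Clone` is not satisfied" -- which is why the macro
    // emits both the check function and a call to it.
    // assert_impls_clone::<NotClone>();
}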
@ -445,11 +495,19 @@ fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStr
}
}
fn check_input_order(item_fn: &ItemFn) -> Option<TokenStream> {
fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option<TokenStream> {
let number_of_inputs = item_fn
.sig
.inputs
.iter()
.filter(|arg| skip_next_arg(arg, kind))
.count();
let types_that_consume_the_request = item_fn
.sig
.inputs
.iter()
.filter(|arg| skip_next_arg(arg, kind))
.enumerate()
.filter_map(|(idx, arg)| {
let ty = match arg {
@ -469,7 +527,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option<TokenStream> {
// exactly one type that consumes the request
if types_that_consume_the_request.len() == 1 {
// and that is not the last
if types_that_consume_the_request[0].0 != item_fn.sig.inputs.len() - 1 {
if types_that_consume_the_request[0].0 != number_of_inputs - 1 {
let (_idx, type_name, span) = &types_that_consume_the_request[0];
let error = format!(
"`{type_name}` consumes the request body and thus must be \
@ -653,13 +711,13 @@ fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream {
}
}
fn check_future_send(item_fn: &ItemFn) -> TokenStream {
fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream {
if item_fn.sig.asyncness.is_none() {
match &item_fn.sig.output {
syn::ReturnType::Default => {
return syn::Error::new_spanned(
item_fn.sig.fn_token,
"Handlers must be `async fn`s",
format!("{} must be `async fn`s", kind.name_uppercase_plural()),
)
.into_compile_error();
}
@ -763,7 +821,69 @@ fn state_types_from_args(item_fn: &ItemFn) -> HashSet<Type> {
crate::infer_state_types(types).collect()
}
fn next_is_last_input(item_fn: &ItemFn) -> TokenStream {
let next_args = item_fn
.sig
.inputs
.iter()
.enumerate()
.filter(|(_, arg)| !skip_next_arg(arg, FunctionKind::Middleware))
.collect::<Vec<_>>();
if next_args.is_empty() {
return quote! {
compile_error!(
"Middleware functions must take `axum::middleware::Next` as the last argument",
);
};
}
if next_args.len() == 1 {
let (idx, arg) = &next_args[0];
if *idx != item_fn.sig.inputs.len() - 1 {
return quote_spanned! {arg.span()=>
compile_error!("`axum::middleware::Next` must be the last argument");
};
}
}
if next_args.len() >= 2 {
return quote! {
compile_error!(
"Middleware functions can only take one argument of type `axum::middleware::Next`",
);
};
}
quote! {}
}
fn skip_next_arg(arg: &FnArg, kind: FunctionKind) -> bool {
match kind {
FunctionKind::Handler => true,
FunctionKind::Middleware => match arg {
FnArg::Receiver(_) => true,
FnArg::Typed(pat_type) => {
if let Type::Path(type_path) = &*pat_type.ty {
type_path
.path
.segments
.last()
.map_or(true, |path_segment| path_segment.ident != "Next")
} else {
true
}
}
},
}
}
#[test]
fn ui() {
fn ui_debug_handler() {
crate::run_ui_tests("debug_handler");
}
#[test]
fn ui_debug_middleware() {
crate::run_ui_tests("debug_middleware");
}


@ -44,6 +44,7 @@
#![cfg_attr(test, allow(clippy::float_cmp))]
#![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))]
use debug_handler::FunctionKind;
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse::Parse, Type};
@ -464,7 +465,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
expand_with(item, |item| from_request::expand(item, FromRequestParts))
}
/// Generates better error messages when applied handler functions.
/// Generates better error messages when applied to handler functions.
///
/// While using [`axum`], you can get long error messages for simple mistakes. For example:
///
@ -515,17 +516,15 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream {
///
/// As the error message says, the handler function needs to be async.
///
/// ```
/// ```no_run
/// use axum::{routing::get, Router, debug_handler};
///
/// #[tokio::main]
/// async fn main() {
/// # async {
/// let app = Router::new().route("/", get(handler));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, app).await.unwrap();
/// # };
/// }
///
/// #[debug_handler]
@ -618,7 +617,65 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
return input;
#[cfg(debug_assertions)]
return expand_attr_with(_attr, input, debug_handler::expand);
return expand_attr_with(_attr, input, |attrs, item_fn| {
debug_handler::expand(attrs, item_fn, FunctionKind::Handler)
});
}
/// Generates better error messages when applied to middleware functions.
///
/// This works similarly to [`#[debug_handler]`](macro@debug_handler) except for middleware using
/// [`axum::middleware::from_fn`].
///
/// # Example
///
/// ```no_run
/// use axum::{
/// routing::get,
/// extract::Request,
/// response::Response,
/// Router,
/// middleware::{self, Next},
/// debug_middleware,
/// };
///
/// #[tokio::main]
/// async fn main() {
/// let app = Router::new()
/// .route("/", get(|| async {}))
/// .layer(middleware::from_fn(my_middleware));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, app).await.unwrap();
/// }
///
/// // if this wasn't a valid middleware function, #[debug_middleware] would
/// // improve the compile error
/// #[debug_middleware]
/// async fn my_middleware(
/// request: Request,
/// next: Next,
/// ) -> Response {
/// next.run(request).await
/// }
/// ```
///
/// # Performance
///
/// This macro has no effect when compiled with the release profile (e.g. `cargo build --release`).
///
/// [`axum`]: https://docs.rs/axum/latest
/// [`axum::middleware::from_fn`]: https://docs.rs/axum/0.7/axum/middleware/fn.from_fn.html
/// [`debug_middleware`]: macro@debug_middleware
#[proc_macro_attribute]
pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream {
#[cfg(not(debug_assertions))]
return input;
#[cfg(debug_assertions)]
return expand_attr_with(_attr, input, |attrs, item_fn| {
debug_handler::expand(attrs, item_fn, FunctionKind::Middleware)
});
}
/// Private API: Do not use this!


@ -0,0 +1,9 @@
use axum::extract::Extension;
use axum_macros::debug_handler;
struct NonCloneType;
#[debug_handler]
async fn test_extension_non_clone(_: Extension<NonCloneType>) {}
fn main() {}


@ -0,0 +1,28 @@
error[E0277]: the trait bound `NonCloneType: Clone` is not satisfied
--> tests/debug_handler/fail/extension_not_clone.rs:7:38
|
7 | async fn test_extension_non_clone(_: Extension<NonCloneType>) {}
| ^^^^^^^^^^^^^^^^^^^^^^^ the trait `Clone` is not implemented for `NonCloneType`, which is required by `Extension<NonCloneType>: FromRequest<(), _>`
|
= help: the following other types implement trait `FromRequest<S, M>`:
axum::body::Bytes
Body
Form<T>
Json<T>
axum::http::Request<Body>
RawForm
String
Option<T>
and $N others
= note: required for `Extension<NonCloneType>` to implement `FromRequestParts<()>`
= note: required for `Extension<NonCloneType>` to implement `FromRequest<(), axum_core::extract::private::ViaParts>`
note: required by a bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check`
--> tests/debug_handler/fail/extension_not_clone.rs:7:38
|
7 | async fn test_extension_non_clone(_: Extension<NonCloneType>) {}
| ^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check`
help: consider annotating `NonCloneType` with `#[derive(Clone)]`
|
4 + #[derive(Clone)]
5 | struct NonCloneType;
|
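For contrast with the failing UI test above, a hedged sketch (hypothetical type and function names) of the version the compiler's help text suggests, where deriving `Clone` lets `Extension<_>` satisfy the generated `FromRequestParts` bound:

use axum::extract::Extension;
use axum_macros::debug_handler;

// `Extension<T>` only implements `FromRequestParts` when `T: Clone`,
// so deriving `Clone` makes the `#[debug_handler]` check pass.
#[derive(Clone)]
struct CloneableType;

#[debug_handler]
async fn test_extension_clone(_: Extension<CloneableType>) {}

fn main() {
    let _ = test_extension_clone; // reference the handler so it is not dead code
}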


@ -0,0 +1,13 @@
use axum::{
debug_middleware,
extract::Request,
response::{IntoResponse, Response},
};
#[debug_middleware]
async fn my_middleware(request: Request) -> Response {
let _ = request;
().into_response()
}
fn main() {}


@ -0,0 +1,7 @@
error: Middleware functions must take `axum::middleware::Next` as the last argument
--> tests/debug_middleware/fail/doesnt_take_next.rs:7:1
|
7 | #[debug_middleware]
| ^^^^^^^^^^^^^^^^^^^
|
= note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info)


@ -0,0 +1,13 @@
use axum::{
extract::Request,
response::Response,
middleware::Next,
debug_middleware,
};
#[debug_middleware]
async fn my_middleware(next: Next, request: Request) -> Response {
next.run(request).await
}
fn main() {}


@ -0,0 +1,5 @@
error: `axum::middleware::Next` must be the last argument
--> tests/debug_middleware/fail/next_not_last.rs:9:24
|
9 | async fn my_middleware(next: Next, request: Request) -> Response {
| ^^^^^^^^^^


@ -0,0 +1,9 @@
use axum::{debug_middleware, extract::Request, middleware::Next, response::Response};
#[debug_middleware]
async fn my_middleware(request: Request, next: Next, next2: Next) -> Response {
let _ = next2;
next.run(request).await
}
fn main() {}


@ -0,0 +1,7 @@
error: Middleware functions can only take one argument of type `axum::middleware::Next`
--> tests/debug_middleware/fail/takes_next_twice.rs:3:1
|
3 | #[debug_middleware]
| ^^^^^^^^^^^^^^^^^^^
|
= note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info)


@ -0,0 +1,13 @@
use axum::{
extract::Request,
response::Response,
middleware::Next,
debug_middleware,
};
#[debug_middleware]
async fn my_middleware(request: Request, next: Next) -> Response {
next.run(request).await
}
fn main() {}


@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
a `Router` or `MethodRouter` ([#2586])
- **fixed:** `h2` is no longer pulled as a dependency unless the `http2` feature
is enabled ([#2605])
- **added:** Add `#[debug_middleware]` ([#1993], [#2725])
[#1993]: https://github.com/tokio-rs/axum/pull/1993
[#2725]: https://github.com/tokio-rs/axum/pull/2725
[#2586]: https://github.com/tokio-rs/axum/pull/2586
[#2605]: https://github.com/tokio-rs/axum/pull/2605


@ -54,6 +54,7 @@ memchr = "2.4.1"
mime = "0.3.16"
percent-encoding = "2.1"
pin-project-lite = "0.2.7"
rustversion = "1.0.9"
serde = "1.0"
sync_wrapper = "1.0.0"
tower = { version = "0.4.13", default-features = false, features = ["util"] }
@ -64,7 +65,7 @@ tower-service = "0.3"
axum-macros = { path = "../axum-macros", version = "0.4.1", optional = true }
base64 = { version = "0.21.0", optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server"], optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
multer = { version = "3.0.0", optional = true }
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
serde_path_to_error = { version = "0.1.8", optional = true }
@ -109,16 +110,12 @@ features = [
"validate-request",
]
[build-dependencies]
rustversion = "1.0.9"
[dev-dependencies]
anyhow = "1.0"
axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
rustversion = "1.0.9"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
time = { version = "0.3", features = ["serde-human-readable"] }


@ -1,7 +0,0 @@
#[rustversion::nightly]
fn main() {
println!("cargo:rustc-cfg=nightly_error_messages");
}
#[rustversion::not(nightly)]
fn main() {}


@ -103,6 +103,7 @@ where
}
}
#[allow(dead_code)]
pub(crate) struct MakeErasedRouter<S> {
pub(crate) router: Router<S>,
pub(crate) into_route: fn(Router<S>, S) -> Route,


@ -19,8 +19,7 @@ Types and traits for extracting data from requests.
A handler function is an async function that takes any number of
"extractors" as arguments. An extractor is a type that implements
[`FromRequest`](crate::extract::FromRequest)
or [`FromRequestParts`](crate::extract::FromRequestParts).
[`FromRequest`] or [`FromRequestParts`].
For example, [`Json`] is an extractor that consumes the request body and
deserializes it as JSON into some target type:
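A hedged stand-in for that example (hypothetical payload type, assuming axum's default `json` feature and `serde`'s `derive` feature), showing a handler whose `Json` argument consumes and deserializes the body:

use axum::{routing::post, Json, Router};
use serde::Deserialize;

#[derive(Deserialize)]
struct CreateUser {
    username: String,
}

// `Json<CreateUser>` consumes the request body and deserializes it before
// the handler body runs; a malformed body is rejected before this point.
async fn create_user(Json(payload): Json<CreateUser>) -> String {
    format!("created {}", payload.username)
}

fn main() {
    // Building the router is enough to type-check the handler.
    let _app: Router = Router::new().route("/users", post(create_user));
}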
@ -284,7 +283,7 @@ let app = Router::new().route("/users", post(create_user));
# Customizing extractor responses
If an extractor fails it will return a response with the error and your
handler will not be called. To customize the error response you have a two
handler will not be called. To customize the error response you have two
options:
1. Use `Result<T, T::Rejection>` as your extractor like shown in ["Optional


@ -166,7 +166,7 @@ In general you can return tuples like:
This means you cannot accidentally override the status or body as [`IntoResponseParts`] only allows
setting headers and extensions.
Use [`Response`](crate::response::Response) for more low level control:
Use [`Response`] for more low level control:
```rust,no_run
use axum::{


@ -181,7 +181,7 @@ router.
# Panics
- If the route overlaps with another route. See [`Router::route`]
for more details.
for more details.
- If the route contains a wildcard (`*`).
- If `path` is empty.


@ -5,8 +5,7 @@ can be either static, a capture, or a wildcard.
`method_router` is the [`MethodRouter`] that should receive the request if the
path matches `path`. `method_router` will commonly be a handler wrapped in a method
router like [`get`](crate::routing::get). See [`handler`](crate::handler) for
more details on handlers.
router like [`get`]. See [`handler`](crate::handler) for more details on handlers.
# Static paths
@ -56,7 +55,22 @@ Note that `/*key` doesn't match empty segments. Thus:
- `/*key` doesn't match `/` but does match `/a`, `/a/`, etc.
- `/x/*key` doesn't match `/x` or `/x/` but does match `/x/a`, `/x/a/`, etc.
Wildcard captures can also be extracted using [`Path`](crate::extract::Path).
Wildcard captures can also be extracted using [`Path`](crate::extract::Path):
```rust
use axum::{
Router,
routing::get,
extract::Path,
};
let app: Router = Router::new().route("/*key", get(handler));
async fn handler(Path(path): Path<String>) -> String {
path
}
```
Note that the leading slash is not included, i.e. for the route `/foo/*rest` and
the path `/foo/bar/baz` the value of `rest` will be `bar/baz`.


@ -125,8 +125,8 @@ pub use self::service::HandlerService;
/// )));
/// # let _: Router = app;
/// ```
#[cfg_attr(
nightly_error_messages,
#[rustversion::attr(
since(1.78),
diagnostic::on_unimplemented(
note = "Consider using `#[axum::debug_handler]` to improve the error message"
)


@ -17,8 +17,7 @@ use serde::{de::DeserializeOwned, Serialize};
///
/// - The request doesn't have a `Content-Type: application/json` (or similar) header.
/// - The body doesn't contain syntactically valid JSON.
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target
/// type.
/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type.
/// - Buffering the request body fails.
///
/// ⚠️ Since parsing JSON requires consuming the request body, the `Json` extractor must be


@ -463,7 +463,7 @@ pub use self::form::Form;
pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt};
#[cfg(feature = "macros")]
pub use axum_macros::debug_handler;
pub use axum_macros::{debug_handler, debug_middleware};
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[doc(inline)]


@ -185,6 +185,7 @@ where
}
#[doc = include_str!("../docs/routing/nest.md")]
#[doc(alias = "scope")] // Some web frameworks like actix-web use this term
#[track_caller]
pub fn nest(self, path: &str, router: Router<S>) -> Self {
let RouterInner {


@ -7,9 +7,7 @@ use std::{
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
@ -18,13 +16,12 @@ use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::server::conn::auto::Builder;
use pin_project_lite::pin_project;
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::util::{Oneshot, ServiceExt};
use tower::ServiceExt as _;
use tower_service::Service;
/// Serve the service with the supplied listener.
@ -243,11 +240,10 @@ where
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));
let hyper_service = TowerToHyperService {
service: tower_service,
};
let hyper_service = TowerToHyperService::new(tower_service);
tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
@ -404,11 +400,10 @@ where
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));
let hyper_service = TowerToHyperService {
service: tower_service,
};
let hyper_service = TowerToHyperService::new(tower_service);
let signal_tx = Arc::clone(&signal_tx);
@ -518,49 +513,6 @@ mod private {
}
}
#[derive(Debug, Copy, Clone)]
struct TowerToHyperService<S> {
service: S,
}
impl<S> hyper::service::Service<Request<Incoming>> for TowerToHyperService<S>
where
S: tower_service::Service<Request> + Clone,
{
type Response = S::Response;
type Error = S::Error;
type Future = TowerToHyperServiceFuture<S, Request>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let req = req.map(Body::new);
TowerToHyperServiceFuture {
future: self.service.clone().oneshot(req),
}
}
}
pin_project! {
struct TowerToHyperServiceFuture<S, R>
where
S: tower_service::Service<R>,
{
#[pin]
future: Oneshot<S, R>,
}
}
impl<S, R> Future for TowerToHyperServiceFuture<S, R>
where
S: tower_service::Service<R>,
{
type Output = Result<S::Response, S::Error>;
#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().future.poll(cx)
}
}
/// An incoming stream.
///
/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].


@ -130,8 +130,8 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
// If any one of the tasks run to completion, we abort the other.
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
_ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(),
};
// Send "user left" message (similar to "joined" above).


@ -43,7 +43,7 @@ async fn serve_plain() {
// We don't need to call `poll_ready` because `Router` is always ready.
let tower_service = app.clone();
// Spawn a task to handle the connection. That way we can multiple connections
// Spawn a task to handle the connection. That way we can handle multiple connections
// concurrently.
tokio::spawn(async move {
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.