diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 11a128c3..72bf2f47 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,11 +1,8 @@ -use axum::{ - extract::FromRequestParts, - response::{IntoResponse, Response}, - Error, -}; -use http::{request::Parts, StatusCode}; +use axum::extract::FromRequestParts; +use axum_core::__composite_rejection as composite_rejection; +use axum_core::__define_rejection as define_rejection; +use http::request::Parts; use serde::de::DeserializeOwned; -use std::fmt; /// Extractor that deserializes query strings into some type. /// @@ -93,63 +90,27 @@ where let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); let value = serde_path_to_error::deserialize(deserializer) - .map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?; + .map_err(FailedToDeserializeQueryString::from_err)?; Ok(Query(value)) } } axum_core::__impl_deref!(Query); -/// Rejection used for [`Query`]. -/// -/// Contains one variant for each way the [`Query`] extractor can fail. -#[derive(Debug)] -#[non_exhaustive] -#[cfg(feature = "query")] -pub enum QueryRejection { - #[allow(missing_docs)] - FailedToDeserializeQueryString(Error), +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to deserialize query string"] + /// Rejection type used if the [`Query`] extractor is unable to + /// deserialize the query string into the target type. + pub struct FailedToDeserializeQueryString(Error); } -impl QueryRejection { - /// Get the status code used for this rejection. - pub fn status(&self) -> StatusCode { - match self { - Self::FailedToDeserializeQueryString(_) => StatusCode::BAD_REQUEST, - } - } -} - -impl IntoResponse for QueryRejection { - fn into_response(self) -> Response { - let status = self.status(); - match self { - Self::FailedToDeserializeQueryString(inner) => { - let body = format!("Failed to deserialize query string: {inner}"); - axum_core::__log_rejection!( - rejection_type = Self, - body_text = body, - status = status, - ); - (status, body).into_response() - } - } - } -} - -impl fmt::Display for QueryRejection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), - } - } -} - -impl std::error::Error for QueryRejection { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::FailedToDeserializeQueryString(inner) => Some(inner), - } +composite_rejection! { + /// Rejection used for [`Query`]. + /// + /// Contains one variant for each way the [`Query`] extractor can fail. + pub enum QueryRejection { + FailedToDeserializeQueryString, } } @@ -207,9 +168,8 @@ where if let Some(query) = parts.uri.query() { let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes())); - let value = serde_path_to_error::deserialize(deserializer).map_err(|err| { - OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err)) - })?; + let value = serde_path_to_error::deserialize(deserializer) + .map_err(FailedToDeserializeQueryString::from_err)?; Ok(OptionalQuery(Some(value))) } else { Ok(OptionalQuery(None)) @@ -233,42 +193,12 @@ impl std::ops::DerefMut for OptionalQuery { } } -/// Rejection used for [`OptionalQuery`]. -/// -/// Contains one variant for each way the [`OptionalQuery`] extractor can fail. -#[derive(Debug)] -#[non_exhaustive] -#[cfg(feature = "query")] -pub enum OptionalQueryRejection { - #[allow(missing_docs)] - FailedToDeserializeQueryString(Error), -} - -impl IntoResponse for OptionalQueryRejection { - fn into_response(self) -> Response { - match self { - Self::FailedToDeserializeQueryString(inner) => ( - StatusCode::BAD_REQUEST, - format!("Failed to deserialize query string: {inner}"), - ) - .into_response(), - } - } -} - -impl fmt::Display for OptionalQueryRejection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::FailedToDeserializeQueryString(inner) => inner.fmt(f), - } - } -} - -impl std::error::Error for OptionalQueryRejection { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::FailedToDeserializeQueryString(inner) => Some(inner), - } +composite_rejection! { + /// Rejection used for [`OptionalQuery`]. + /// + /// Contains one variant for each way the [`OptionalQuery`] extractor can fail. + pub enum OptionalQueryRejection { + FailedToDeserializeQueryString, } } @@ -279,6 +209,7 @@ mod tests { use axum::routing::{get, post}; use axum::Router; use http::header::CONTENT_TYPE; + use http::StatusCode; use serde::Deserialize; #[tokio::test]