diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index adad8ac6..f31c5499 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -1,11 +1,11 @@ +use axum::extract::rejection::RawFormRejection; use axum::{ - extract::{rejection::RawFormRejection, FromRequest, RawForm, Request}, - response::{IntoResponse, Response}, - Error, RequestExt, + extract::{FromRequest, RawForm, Request}, + RequestExt, }; -use http::StatusCode; +use axum_core::__composite_rejection as composite_rejection; +use axum_core::__define_rejection as define_rejection; use serde::de::DeserializeOwned; -use std::fmt; /// Extractor that deserializes `application/x-www-form-urlencoded` requests /// into some type. @@ -51,76 +51,32 @@ where type Rejection = FormRejection; async fn from_request(req: Request, _state: &S) -> Result { - let RawForm(bytes) = req - .extract() - .await - .map_err(FormRejection::RawFormRejection)?; + let RawForm(bytes) = req.extract().await?; let deserializer = serde_html_form::Deserializer::new(form_urlencoded::parse(&bytes)); - serde_path_to_error::deserialize::<_, T>(deserializer) - .map(Self) - .map_err(|err| FormRejection::FailedToDeserializeForm(Error::new(err))) + let value = serde_path_to_error::deserialize::<_, T>(deserializer) + .map_err(FailedToDeserializeForm::from_err)?; + + Ok(Self(value)) } } -/// Rejection used for [`Form`]. -/// -/// Contains one variant for each way the [`Form`] extractor can fail. -#[derive(Debug)] -#[non_exhaustive] -#[cfg(feature = "form")] -pub enum FormRejection { - #[allow(missing_docs)] - RawFormRejection(RawFormRejection), - #[allow(missing_docs)] - FailedToDeserializeForm(Error), +define_rejection! { + #[status = BAD_REQUEST] + #[body = "Failed to deserialize form"] + /// Rejection type used if the [`Form`](Form) extractor is unable to + /// deserialize the form into the target type. + pub struct FailedToDeserializeForm(Error); } -impl FormRejection { - /// Get the status code used for this rejection. - pub fn status(&self) -> StatusCode { - // Make sure to keep this in sync with `IntoResponse` impl. - match self { - Self::RawFormRejection(inner) => inner.status(), - Self::FailedToDeserializeForm(_) => StatusCode::BAD_REQUEST, - } - } -} - -impl IntoResponse for FormRejection { - fn into_response(self) -> Response { - let status = self.status(); - match self { - Self::RawFormRejection(inner) => inner.into_response(), - Self::FailedToDeserializeForm(inner) => { - let body = format!("Failed to deserialize form: {inner}"); - axum_core::__log_rejection!( - rejection_type = Self, - body_text = body, - status = status, - ); - (status, body).into_response() - } - } - } -} - -impl fmt::Display for FormRejection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::RawFormRejection(inner) => inner.fmt(f), - Self::FailedToDeserializeForm(inner) => inner.fmt(f), - } - } -} - -impl std::error::Error for FormRejection { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::RawFormRejection(inner) => Some(inner), - Self::FailedToDeserializeForm(inner) => Some(inner), - } +composite_rejection! { + /// Rejection used for [`Form`]. + /// + /// Contains one variant for each way the [`Form`] extractor can fail. + pub enum FormRejection { + RawFormRejection, + FailedToDeserializeForm, } } @@ -131,6 +87,7 @@ mod tests { use axum::routing::{on, post, MethodFilter}; use axum::Router; use http::header::CONTENT_TYPE; + use http::StatusCode; use mime::APPLICATION_WWW_FORM_URLENCODED; use serde::Deserialize;