From 84c725a1ae40ee3ae4dce83da4250447a3242215 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 2 Mar 2022 12:41:14 +0100 Subject: [PATCH] Make `IntoResponseParts` more flexible (#813) * Make `IntoResponseParts` more flexible * fix `impl IntoResponseParts for TypedHeader` * fix --- axum-core/src/response/into_response.rs | 54 ++--- axum-core/src/response/into_response_parts.rs | 194 ++++++++++-------- axum-core/src/response/mod.rs | 2 +- axum/src/extension.rs | 8 +- axum/src/typed_header.rs | 29 +-- 5 files changed, 149 insertions(+), 138 deletions(-) diff --git a/axum-core/src/response/into_response.rs b/axum-core/src/response/into_response.rs index f11158d7..8f49d317 100644 --- a/axum-core/src/response/into_response.rs +++ b/axum-core/src/response/into_response.rs @@ -460,16 +460,18 @@ macro_rules! impl_into_response { let ($($ty),*, res) = self; let res = res.into_response(); - let mut parts = ResponseParts { res: Ok(res) }; + let parts = ResponseParts { res }; $( - $ty.into_response_parts(&mut parts); + let parts = match $ty.into_response_parts(parts) { + Ok(parts) => parts, + Err(err) => { + return err.into_response(); + } + }; )* - match parts.res { - Ok(res) => res, - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(), - } + parts.res } } @@ -483,19 +485,20 @@ macro_rules! impl_into_response { let (status, $($ty),*, res) = self; let res = res.into_response(); - let mut parts = ResponseParts { res: Ok(res) }; + let parts = ResponseParts { res }; $( - $ty.into_response_parts(&mut parts); + let parts = match $ty.into_response_parts(parts) { + Ok(parts) => parts, + Err(err) => { + return err.into_response(); + } + }; )* - match parts.res { - Ok(mut res) => { - *res.status_mut() = status; - res - } - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(), - } + let mut res = parts.res; + *res.status_mut() = status; + res } } @@ -509,20 +512,21 @@ macro_rules! impl_into_response { let (version, status, $($ty),*, res) = self; let res = res.into_response(); - let mut parts = ResponseParts { res: Ok(res) }; + let parts = ResponseParts { res }; $( - $ty.into_response_parts(&mut parts); + let parts = match $ty.into_response_parts(parts) { + Ok(parts) => parts, + Err(err) => { + return err.into_response(); + } + }; )* - match parts.res { - Ok(mut res) => { - *res.version_mut() = version; - *res.status_mut() = status; - res - } - Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(), - } + let mut res = parts.res; + *res.version_mut() = version; + *res.status_mut() = status; + res } } } diff --git a/axum-core/src/response/into_response_parts.rs b/axum-core/src/response/into_response_parts.rs index 4a403e6d..bbe603ac 100644 --- a/axum-core/src/response/into_response_parts.rs +++ b/axum-core/src/response/into_response_parts.rs @@ -1,14 +1,22 @@ -use super::Response; -use http::header::{HeaderMap, HeaderName, HeaderValue}; -use std::{convert::TryInto, fmt}; +use super::{IntoResponse, Response}; +use http::{ + header::{HeaderMap, HeaderName, HeaderValue}, + Extensions, StatusCode, +}; +use std::{ + convert::{Infallible, TryInto}, + fmt, +}; /// Trait for adding headers and extensions to a response. -/// -/// You generally don't need to implement this trait manually. It's recommended instead to rely -/// on the implementations in axum. pub trait IntoResponseParts { + /// The type returned in the event of an error. + /// + /// This can be used to fallibly convert types into headers or extensions. + type Error: IntoResponse; + /// Set parts of the response - fn into_response_parts(self, res: &mut ResponseParts); + fn into_response_parts(self, res: ResponseParts) -> Result; } /// Parts of a response. @@ -16,96 +24,37 @@ pub trait IntoResponseParts { /// Used with [`IntoResponseParts`]. #[derive(Debug)] pub struct ResponseParts { - pub(crate) res: Result, + pub(crate) res: Response, } impl ResponseParts { - /// Insert a header into the response. - /// - /// If the header already exists, it will be overwritten. - pub fn insert_header(&mut self, key: K, value: V) - where - K: TryInto, - K::Error: fmt::Display, - V: TryInto, - V::Error: fmt::Display, - { - self.update_headers(key, value, |headers, key, value| { - headers.insert(key, value); - }); + /// Gets a reference to the response headers. + pub fn headers(&self) -> &HeaderMap { + self.res.headers() } - /// Append a header to the response. - /// - /// If the header already exists it will be appended to. - pub fn append_header(&mut self, key: K, value: V) - where - K: TryInto, - K::Error: fmt::Display, - V: TryInto, - V::Error: fmt::Display, - { - self.update_headers(key, value, |headers, key, value| { - headers.append(key, value); - }); + /// Gets a mutable reference to the response headers. + pub fn headers_mut(&mut self) -> &mut HeaderMap { + self.res.headers_mut() } - fn update_headers(&mut self, key: K, value: V, f: F) - where - K: TryInto, - K::Error: fmt::Display, - V: TryInto, - V::Error: fmt::Display, - F: FnOnce(&mut HeaderMap, HeaderName, HeaderValue), - { - if let Ok(response) = &mut self.res { - let key = match key.try_into() { - Ok(key) => key, - Err(err) => { - self.res = Err(err.to_string()); - return; - } - }; - - let value = match value.try_into() { - Ok(value) => value, - Err(err) => { - self.res = Err(err.to_string()); - return; - } - }; - - f(response.headers_mut(), key, value); - } + /// Gets a reference to the response extensions. + pub fn extensions(&self) -> &Extensions { + self.res.extensions() } - /// Insert an extension into the response. - /// - /// If the extension already exists it will be overwritten. - pub fn insert_extension(&mut self, extension: T) - where - T: Send + Sync + 'static, - { - if let Ok(res) = &mut self.res { - res.extensions_mut().insert(extension); - } - } -} - -impl Extend<(Option, HeaderValue)> for ResponseParts { - fn extend(&mut self, iter: T) - where - T: IntoIterator, HeaderValue)>, - { - if let Ok(res) = &mut self.res { - res.headers_mut().extend(iter); - } + /// Gets a mutable reference to the response extensions. + pub fn extensions_mut(&mut self) -> &mut Extensions { + self.res.extensions_mut() } } impl IntoResponseParts for HeaderMap { - fn into_response_parts(self, res: &mut ResponseParts) { - res.extend(self); + type Error = Infallible; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.headers_mut().extend(self); + Ok(res) } } @@ -116,9 +65,82 @@ where V: TryInto, V::Error: fmt::Display, { - fn into_response_parts(self, res: &mut ResponseParts) { + type Error = TryIntoHeaderError; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { for (key, value) in self { - res.insert_header(key, value); + let key = key.try_into().map_err(TryIntoHeaderError::key)?; + let value = value.try_into().map_err(TryIntoHeaderError::value)?; + res.headers_mut().insert(key, value); + } + + Ok(res) + } +} + +/// Error returned if converting a value to a header fails. +#[derive(Debug)] +pub struct TryIntoHeaderError { + kind: TryIntoHeaderErrorKind, +} + +impl TryIntoHeaderError { + fn key(err: K) -> Self { + Self { + kind: TryIntoHeaderErrorKind::Key(err), + } + } + + fn value(err: V) -> Self { + Self { + kind: TryIntoHeaderErrorKind::Value(err), + } + } +} + +#[derive(Debug)] +enum TryIntoHeaderErrorKind { + Key(K), + Value(V), +} + +impl IntoResponse for TryIntoHeaderError +where + K: fmt::Display, + V: fmt::Display, +{ + fn into_response(self) -> Response { + match self.kind { + TryIntoHeaderErrorKind::Key(inner) => { + (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response() + } + TryIntoHeaderErrorKind::Value(inner) => { + (StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response() + } + } + } +} + +impl fmt::Display for TryIntoHeaderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"), + TryIntoHeaderErrorKind::Value(_) => { + write!(f, "failed to convert value to a header value") + } + } + } +} + +impl std::error::Error for TryIntoHeaderError +where + K: std::error::Error + 'static, + V: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.kind { + TryIntoHeaderErrorKind::Key(inner) => Some(inner), + TryIntoHeaderErrorKind::Value(inner) => Some(inner), } } } diff --git a/axum-core/src/response/mod.rs b/axum-core/src/response/mod.rs index d5bbf9a8..687743fe 100644 --- a/axum-core/src/response/mod.rs +++ b/axum-core/src/response/mod.rs @@ -11,7 +11,7 @@ mod into_response_parts; pub use self::{ into_response::IntoResponse, - into_response_parts::{IntoResponseParts, ResponseParts}, + into_response_parts::{IntoResponseParts, ResponseParts, TryIntoHeaderError}, }; /// Type alias for [`http::Response`] whose body type defaults to [`BoxBody`], the most common body diff --git a/axum/src/extension.rs b/axum/src/extension.rs index ece4d68e..83ba7111 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -6,6 +6,7 @@ use async_trait::async_trait; use axum_core::response::{IntoResponse, Response, ResponseParts}; use http::Request; use std::{ + convert::Infallible, ops::Deref, task::{Context, Poll}, }; @@ -107,8 +108,11 @@ impl IntoResponseParts for Extension where T: Send + Sync + 'static, { - fn into_response_parts(self, res: &mut ResponseParts) { - res.insert_extension(self.0); + type Error = Infallible; + + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.extensions_mut().insert(self.0); + Ok(res) } } diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index 5c311678..bb5bc721 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -2,8 +2,7 @@ use crate::extract::{FromRequest, RequestParts}; use async_trait::async_trait; use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use headers::HeaderMapExt; -use http::header::{HeaderName, HeaderValue}; -use std::ops::Deref; +use std::{convert::Infallible, ops::Deref}; /// Extractor and response that works with typed header values from [`headers`]. /// @@ -87,29 +86,11 @@ impl IntoResponseParts for TypedHeader where T: headers::Header, { - fn into_response_parts(self, res: &mut ResponseParts) { - struct ExtendHeaders<'a> { - res: &'a mut ResponseParts, - key: &'static HeaderName, - } + type Error = Infallible; - impl<'a> Extend for ExtendHeaders<'a> { - fn extend(&mut self, iter: T) - where - T: IntoIterator, - { - for value in iter { - self.res.append_header(self.key, value); - } - } - } - - let mut extend = ExtendHeaders { - res, - key: T::name(), - }; - - self.0.encode(&mut extend); + fn into_response_parts(self, mut res: ResponseParts) -> Result { + res.headers_mut().typed_insert(self.0); + Ok(res) } }