Make IntoResponseParts more flexible (#813)

* Make `IntoResponseParts` more flexible

* fix `impl<T> IntoResponseParts for TypedHeader<T>`

* fix
This commit is contained in:
David Pedersen 2022-03-02 12:41:14 +01:00 committed by GitHub
parent 24359ebd4d
commit 84c725a1ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 149 additions and 138 deletions

View file

@ -460,16 +460,18 @@ macro_rules! impl_into_response {
let ($($ty),*, res) = self; let ($($ty),*, res) = self;
let res = res.into_response(); let res = res.into_response();
let mut parts = ResponseParts { res: Ok(res) }; let parts = ResponseParts { res };
$( $(
$ty.into_response_parts(&mut parts); let parts = match $ty.into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
};
)* )*
match parts.res { parts.res
Ok(res) => res,
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(),
}
} }
} }
@ -483,20 +485,21 @@ macro_rules! impl_into_response {
let (status, $($ty),*, res) = self; let (status, $($ty),*, res) = self;
let res = res.into_response(); let res = res.into_response();
let mut parts = ResponseParts { res: Ok(res) }; let parts = ResponseParts { res };
$( $(
$ty.into_response_parts(&mut parts); let parts = match $ty.into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
};
)* )*
match parts.res { let mut res = parts.res;
Ok(mut res) => {
*res.status_mut() = status; *res.status_mut() = status;
res res
} }
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(),
}
}
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
@ -509,21 +512,22 @@ macro_rules! impl_into_response {
let (version, status, $($ty),*, res) = self; let (version, status, $($ty),*, res) = self;
let res = res.into_response(); let res = res.into_response();
let mut parts = ResponseParts { res: Ok(res) }; let parts = ResponseParts { res };
$( $(
$ty.into_response_parts(&mut parts); let parts = match $ty.into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
};
)* )*
match parts.res { let mut res = parts.res;
Ok(mut res) => {
*res.version_mut() = version; *res.version_mut() = version;
*res.status_mut() = status; *res.status_mut() = status;
res res
} }
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err).into_response(),
}
}
} }
} }
} }

View file

@ -1,14 +1,22 @@
use super::Response; use super::{IntoResponse, Response};
use http::header::{HeaderMap, HeaderName, HeaderValue}; use http::{
use std::{convert::TryInto, fmt}; header::{HeaderMap, HeaderName, HeaderValue},
Extensions, StatusCode,
};
use std::{
convert::{Infallible, TryInto},
fmt,
};
/// Trait for adding headers and extensions to a response. /// Trait for adding headers and extensions to a response.
///
/// You generally don't need to implement this trait manually. It's recommended instead to rely
/// on the implementations in axum.
pub trait IntoResponseParts { pub trait IntoResponseParts {
/// The type returned in the event of an error.
///
/// This can be used to fallibly convert types into headers or extensions.
type Error: IntoResponse;
/// Set parts of the response /// Set parts of the response
fn into_response_parts(self, res: &mut ResponseParts); fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error>;
} }
/// Parts of a response. /// Parts of a response.
@ -16,96 +24,37 @@ pub trait IntoResponseParts {
/// Used with [`IntoResponseParts`]. /// Used with [`IntoResponseParts`].
#[derive(Debug)] #[derive(Debug)]
pub struct ResponseParts { pub struct ResponseParts {
pub(crate) res: Result<Response, String>, pub(crate) res: Response,
} }
impl ResponseParts { impl ResponseParts {
/// Insert a header into the response. /// Gets a reference to the response headers.
/// pub fn headers(&self) -> &HeaderMap {
/// If the header already exists, it will be overwritten. self.res.headers()
pub fn insert_header<K, V>(&mut self, key: K, value: V)
where
K: TryInto<HeaderName>,
K::Error: fmt::Display,
V: TryInto<HeaderValue>,
V::Error: fmt::Display,
{
self.update_headers(key, value, |headers, key, value| {
headers.insert(key, value);
});
} }
/// Append a header to the response. /// Gets a mutable reference to the response headers.
/// pub fn headers_mut(&mut self) -> &mut HeaderMap {
/// If the header already exists it will be appended to. self.res.headers_mut()
pub fn append_header<K, V>(&mut self, key: K, value: V)
where
K: TryInto<HeaderName>,
K::Error: fmt::Display,
V: TryInto<HeaderValue>,
V::Error: fmt::Display,
{
self.update_headers(key, value, |headers, key, value| {
headers.append(key, value);
});
} }
fn update_headers<K, V, F>(&mut self, key: K, value: V, f: F) /// Gets a reference to the response extensions.
where pub fn extensions(&self) -> &Extensions {
K: TryInto<HeaderName>, self.res.extensions()
K::Error: fmt::Display,
V: TryInto<HeaderValue>,
V::Error: fmt::Display,
F: FnOnce(&mut HeaderMap, HeaderName, HeaderValue),
{
if let Ok(response) = &mut self.res {
let key = match key.try_into() {
Ok(key) => key,
Err(err) => {
self.res = Err(err.to_string());
return;
}
};
let value = match value.try_into() {
Ok(value) => value,
Err(err) => {
self.res = Err(err.to_string());
return;
}
};
f(response.headers_mut(), key, value);
}
} }
/// Insert an extension into the response. /// Gets a mutable reference to the response extensions.
/// pub fn extensions_mut(&mut self) -> &mut Extensions {
/// If the extension already exists it will be overwritten. self.res.extensions_mut()
pub fn insert_extension<T>(&mut self, extension: T)
where
T: Send + Sync + 'static,
{
if let Ok(res) = &mut self.res {
res.extensions_mut().insert(extension);
}
}
}
impl Extend<(Option<HeaderName>, HeaderValue)> for ResponseParts {
fn extend<T>(&mut self, iter: T)
where
T: IntoIterator<Item = (Option<HeaderName>, HeaderValue)>,
{
if let Ok(res) = &mut self.res {
res.headers_mut().extend(iter);
}
} }
} }
impl IntoResponseParts for HeaderMap { impl IntoResponseParts for HeaderMap {
fn into_response_parts(self, res: &mut ResponseParts) { type Error = Infallible;
res.extend(self);
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut().extend(self);
Ok(res)
} }
} }
@ -116,9 +65,82 @@ where
V: TryInto<HeaderValue>, V: TryInto<HeaderValue>,
V::Error: fmt::Display, V::Error: fmt::Display,
{ {
fn into_response_parts(self, res: &mut ResponseParts) { type Error = TryIntoHeaderError<K::Error, V::Error>;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
for (key, value) in self { for (key, value) in self {
res.insert_header(key, value); let key = key.try_into().map_err(TryIntoHeaderError::key)?;
let value = value.try_into().map_err(TryIntoHeaderError::value)?;
res.headers_mut().insert(key, value);
}
Ok(res)
}
}
/// Error returned if converting a value to a header fails.
#[derive(Debug)]
pub struct TryIntoHeaderError<K, V> {
kind: TryIntoHeaderErrorKind<K, V>,
}
impl<K, V> TryIntoHeaderError<K, V> {
fn key(err: K) -> Self {
Self {
kind: TryIntoHeaderErrorKind::Key(err),
}
}
fn value(err: V) -> Self {
Self {
kind: TryIntoHeaderErrorKind::Value(err),
}
}
}
#[derive(Debug)]
enum TryIntoHeaderErrorKind<K, V> {
Key(K),
Value(V),
}
impl<K, V> IntoResponse for TryIntoHeaderError<K, V>
where
K: fmt::Display,
V: fmt::Display,
{
fn into_response(self) -> Response {
match self.kind {
TryIntoHeaderErrorKind::Key(inner) => {
(StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
}
TryIntoHeaderErrorKind::Value(inner) => {
(StatusCode::INTERNAL_SERVER_ERROR, inner.to_string()).into_response()
}
}
}
}
impl<K, V> fmt::Display for TryIntoHeaderError<K, V> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
TryIntoHeaderErrorKind::Key(_) => write!(f, "failed to convert key to a header name"),
TryIntoHeaderErrorKind::Value(_) => {
write!(f, "failed to convert value to a header value")
}
}
}
}
impl<K, V> std::error::Error for TryIntoHeaderError<K, V>
where
K: std::error::Error + 'static,
V: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.kind {
TryIntoHeaderErrorKind::Key(inner) => Some(inner),
TryIntoHeaderErrorKind::Value(inner) => Some(inner),
} }
} }
} }

View file

@ -11,7 +11,7 @@ mod into_response_parts;
pub use self::{ pub use self::{
into_response::IntoResponse, into_response::IntoResponse,
into_response_parts::{IntoResponseParts, ResponseParts}, into_response_parts::{IntoResponseParts, ResponseParts, TryIntoHeaderError},
}; };
/// Type alias for [`http::Response`] whose body type defaults to [`BoxBody`], the most common body /// Type alias for [`http::Response`] whose body type defaults to [`BoxBody`], the most common body

View file

@ -6,6 +6,7 @@ use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response, ResponseParts}; use axum_core::response::{IntoResponse, Response, ResponseParts};
use http::Request; use http::Request;
use std::{ use std::{
convert::Infallible,
ops::Deref, ops::Deref,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -107,8 +108,11 @@ impl<T> IntoResponseParts for Extension<T>
where where
T: Send + Sync + 'static, T: Send + Sync + 'static,
{ {
fn into_response_parts(self, res: &mut ResponseParts) { type Error = Infallible;
res.insert_extension(self.0);
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.extensions_mut().insert(self.0);
Ok(res)
} }
} }

View file

@ -2,8 +2,7 @@ use crate::extract::{FromRequest, RequestParts};
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts}; use axum_core::response::{IntoResponse, IntoResponseParts, Response, ResponseParts};
use headers::HeaderMapExt; use headers::HeaderMapExt;
use http::header::{HeaderName, HeaderValue}; use std::{convert::Infallible, ops::Deref};
use std::ops::Deref;
/// Extractor and response that works with typed header values from [`headers`]. /// Extractor and response that works with typed header values from [`headers`].
/// ///
@ -87,29 +86,11 @@ impl<T> IntoResponseParts for TypedHeader<T>
where where
T: headers::Header, T: headers::Header,
{ {
fn into_response_parts(self, res: &mut ResponseParts) { type Error = Infallible;
struct ExtendHeaders<'a> {
res: &'a mut ResponseParts,
key: &'static HeaderName,
}
impl<'a> Extend<HeaderValue> for ExtendHeaders<'a> { fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
fn extend<T>(&mut self, iter: T) res.headers_mut().typed_insert(self.0);
where Ok(res)
T: IntoIterator<Item = HeaderValue>,
{
for value in iter {
self.res.append_header(self.key, value);
}
}
}
let mut extend = ExtendHeaders {
res,
key: T::name(),
};
self.0.encode(&mut extend);
} }
} }