From a0ac8a5b78a214c1af481fdb6f490f82e0aa80df Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 6 Aug 2021 16:31:38 +0800 Subject: [PATCH] Fixed the implementation of `IntoResponse` of `(HeaderMap, T)` and `(StatusCode, HeaderMap, T)` would ignore headers from `T` (#137) Co-authored-by: David Pedersen --- CHANGELOG.md | 1 + src/response.rs | 43 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ade148fc..dd955803 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Re-export `http` crate and `hyper::Server`. ([#110](https://github.com/tokio-rs/axum/pull/110)) - Fix `Query` and `Form` extractors giving bad request error when query string is empty. ([#117](https://github.com/tokio-rs/axum/pull/117)) - Add `Path` extractor. ([#124](https://github.com/tokio-rs/axum/pull/124)) +- Fixed the implementation of `IntoResponse` of `(HeaderMap, T)` and `(StatusCode, HeaderMap, T)` would ignore headers from `T` ([#137](https://github.com/tokio-rs/axum/pull/137)) ## Breaking changes diff --git a/src/response.rs b/src/response.rs index 80d2d7a2..0ed511e4 100644 --- a/src/response.rs +++ b/src/response.rs @@ -158,7 +158,7 @@ where { fn into_response(self) -> Response { let mut res = self.1.into_response(); - *res.headers_mut() = self.0; + res.headers_mut().extend(self.0); res } } @@ -170,7 +170,7 @@ where fn into_response(self) -> Response { let mut res = self.2.into_response(); *res.status_mut() = self.0; - *res.headers_mut() = self.1; + res.headers_mut().extend(self.1); res } } @@ -264,3 +264,42 @@ impl From for Json { Self(inner) } } + +#[cfg(test)] +mod tests { + use super::*; + use http::header::{HeaderMap, HeaderName}; + + #[test] + fn test_merge_headers() { + struct MyResponse; + + impl IntoResponse for MyResponse { + fn into_response(self) -> Response { + let mut resp = Response::new(String::new().into()); + resp.headers_mut() + .insert(HeaderName::from_static("a"), HeaderValue::from_static("1")); + resp + } + } + + fn check(resp: impl IntoResponse) { + let resp = resp.into_response(); + assert_eq!( + resp.headers().get(HeaderName::from_static("a")).unwrap(), + &HeaderValue::from_static("1") + ); + assert_eq!( + resp.headers().get(HeaderName::from_static("b")).unwrap(), + &HeaderValue::from_static("2") + ); + } + + let headers: HeaderMap = + std::iter::once((HeaderName::from_static("b"), HeaderValue::from_static("2"))) + .collect(); + + check((headers.clone(), MyResponse)); + check((StatusCode::OK, headers, MyResponse)); + } +}