Change HeaderMap extractor to clone the headers (#698)

* Change `HeaderMap` extractor to clone the headers

* fix docs

* changelog

* inline variable

* also add changelog item to axum

* don't list types from axum in axum-core's changelog

* document that `HeaderMap::from_request` clones the headers

* fix typo

* a few more typos
This commit is contained in:
David Pedersen 2022-01-11 20:39:39 +01:00
parent d5694f0d0d
commit 184ea656c0
16 changed files with 92 additions and 144 deletions

View file

@ -7,7 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus
they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead
`HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the
headers you need ([#698])
This includes these breaking changes:
- `RequestParts::take_headers` has been removed.
- `RequestParts::headers` returns `&HeaderMap`.
- `RequestParts::headers_mut` returns `&mut HeaderMap`.
- `HeadersAlreadyExtracted` has been removed.
- The `HeadersAlreadyExtracted` variant has been removed from these rejections:
- `RequestAlreadyExtracted`
- `RequestPartsAlreadyExtracted`
- `<HeaderMap as FromRequest<_>>::Error` has been changed to `std::convert::Infallible`.
[#698]: https://github.com/tokio-rs/axum/pull/698
# 0.1.1 (06. December, 2021)

View file

@ -77,7 +77,7 @@ pub struct RequestParts<B> {
method: Method,
uri: Uri,
version: Version,
headers: Option<HeaderMap>,
headers: HeaderMap,
extensions: Option<Extensions>,
body: Option<B>,
}
@ -107,7 +107,7 @@ impl<B> RequestParts<B> {
method,
uri,
version,
headers: Some(headers),
headers,
extensions: Some(extensions),
body: Some(body),
}
@ -117,14 +117,11 @@ impl<B> RequestParts<B> {
///
/// Fails if
///
/// - The full [`HeaderMap`] has been extracted, that is [`take_headers`]
/// have been called.
/// - The full [`Extensions`] has been extracted, that is
/// [`take_extensions`] have been called.
/// - The request body has been extracted, that is [`take_body`] have been
/// called.
///
/// [`take_headers`]: RequestParts::take_headers
/// [`take_extensions`]: RequestParts::take_extensions
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, RequestAlreadyExtracted> {
@ -132,7 +129,7 @@ impl<B> RequestParts<B> {
method,
uri,
version,
mut headers,
headers,
mut extensions,
mut body,
} = self;
@ -148,14 +145,7 @@ impl<B> RequestParts<B> {
*req.method_mut() = method;
*req.uri_mut() = uri;
*req.version_mut() = version;
if let Some(headers) = headers.take() {
*req.headers_mut() = headers;
} else {
return Err(RequestAlreadyExtracted::HeadersAlreadyExtracted(
HeadersAlreadyExtracted,
));
}
*req.headers_mut() = headers;
if let Some(extensions) = extensions.take() {
*req.extensions_mut() = extensions;
@ -199,22 +189,13 @@ impl<B> RequestParts<B> {
}
/// Gets a reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers(&self) -> Option<&HeaderMap> {
self.headers.as_ref()
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
/// Gets a mutable reference to the request headers.
///
/// Returns `None` if the headers has been taken by another extractor.
pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> {
self.headers.as_mut()
}
/// Takes the headers out of the request, leaving a `None` in its place.
pub fn take_headers(&mut self) -> Option<HeaderMap> {
self.headers.take()
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
/// Gets a reference to the request extensions.

View file

@ -8,13 +8,6 @@ define_rejection! {
pub struct BodyAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Headers taken by other extractor"]
/// Rejection used if the headers has been taken by another extractor.
pub struct HeadersAlreadyExtracted;
}
define_rejection! {
#[status = INTERNAL_SERVER_ERROR]
#[body = "Extensions taken by other extractor"]
@ -47,7 +40,6 @@ composite_rejection! {
/// [`Request<_>`]: http::Request
pub enum RequestAlreadyExtracted {
BodyAlreadyExtracted,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
@ -79,7 +71,6 @@ composite_rejection! {
///
/// Contains one variant for each way the [`http::request::Parts`] extractor can fail.
pub enum RequestPartsAlreadyExtracted {
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}

View file

@ -19,7 +19,7 @@ where
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
headers: None,
headers: HeaderMap::new(),
extensions: None,
body: None,
},
@ -65,15 +65,20 @@ where
}
}
/// Clone the headers from the request.
///
/// Prefer using [`TypedHeader`] to extract only the headers you need.
///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait]
impl<B> FromRequest<B> for HeaderMap
where
B: Send,
{
type Rejection = HeadersAlreadyExtracted;
type Rejection = Infallible;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
req.take_headers().ok_or(HeadersAlreadyExtracted)
Ok(req.headers().clone())
}
}
@ -143,7 +148,10 @@ where
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
let headers = HeaderMap::from_request(req).await?;
let headers = match HeaderMap::from_request(req).await {
Ok(headers) => headers,
Err(err) => match err {},
};
let extensions = Extensions::from_request(req).await?;
let mut temp_request = Request::new(());

View file

@ -13,9 +13,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
overwriting old values.
- **breaking:** Require `Output = ()` on `WebSocketStream::on_upgrade` ([#644])
- **breaking:** Make `TypedHeaderRejectionReason` `#[non_exhaustive]` ([#665])
- **breaking:** Using `HeaderMap` as an extractor will no longer remove the headers and thus
they'll still be accessible to other extractors, such as `axum::extract::Json`. Instead
`HeaderMap` will clone the headers. You should prefer to use `TypedHeader` to extract only the
headers you need ([#698])
This includes these breaking changes:
- `RequestParts::take_headers` has been removed.
- `RequestParts::headers` returns `&HeaderMap`.
- `RequestParts::headers_mut` returns `&mut HeaderMap`.
- `HeadersAlreadyExtracted` has been removed.
- The `HeadersAlreadyExtracted` removed variant has been removed from these rejections:
- `RequestAlreadyExtracted`
- `RequestPartsAlreadyExtracted`
- `JsonRejection`
- `FormRejection`
- `ContentLengthLimitRejection`
- `WebSocketUpgradeRejection`
- `<HeaderMap as FromRequest<_>>::Error` has been changed to `std::convert::Infallible`.
[#644]: https://github.com/tokio-rs/axum/pull/644
[#665]: https://github.com/tokio-rs/axum/pull/665
[#698]: https://github.com/tokio-rs/axum/pull/698
# 0.4.4 (13. January, 2021)

View file

@ -320,10 +320,6 @@ async fn handler(result: Result<Json<Value>, JsonRejection>) -> impl IntoRespons
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to buffer request body".to_string(),
)),
JsonRejection::HeadersAlreadyExtracted(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Headers already extracted".to_string(),
)),
// we must provide a catch-all case since `JsonRejection` is marked
// `#[non_exhaustive]`
_ => Err((
@ -377,9 +373,7 @@ where
type Rejection = (StatusCode, &'static str);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT));
if let Some(user_agent) = user_agent {
if let Some(user_agent) = req.headers().get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone()))
} else {
Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))

View file

@ -39,14 +39,7 @@ where
type Rejection = ContentLengthLimitRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let content_length = req
.headers()
.ok_or_else(|| {
ContentLengthLimitRejection::HeadersAlreadyExtracted(
HeadersAlreadyExtracted::default(),
)
})?
.get(http::header::CONTENT_LENGTH);
let content_length = req.headers().get(http::header::CONTENT_LENGTH);
let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());

View file

@ -59,7 +59,7 @@ use tower_service::Service;
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// let auth_header = req
/// .headers()
/// .and_then(|headers| headers.get(http::header::AUTHORIZATION))
/// .get(http::header::AUTHORIZATION)
/// .and_then(|value| value.to_str().ok());
///
/// match auth_header {
@ -291,7 +291,6 @@ mod tests {
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if let Some(auth) = req
.headers()
.expect("headers already extracted")
.get("authorization")
.and_then(|v| v.to_str().ok())
{

View file

@ -60,7 +60,7 @@ where
.map_err(FailedToDeserializeQueryString::new::<T, _>)?;
Ok(Form(value))
} else {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED)? {
if !has_content_type(req, &mime::APPLICATION_WWW_FORM_URLENCODED) {
return Err(InvalidFormContentType.into());
}

View file

@ -78,24 +78,20 @@ pub use self::typed_header::TypedHeader;
pub(crate) fn has_content_type<B>(
req: &RequestParts<B>,
expected_content_type: &mime::Mime,
) -> Result<bool, HeadersAlreadyExtracted> {
let content_type = if let Some(content_type) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(header::CONTENT_TYPE)
{
) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
return Ok(false);
return false;
};
let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return Ok(false);
return false;
};
Ok(content_type.starts_with(expected_content_type.as_ref()))
content_type.starts_with(expected_content_type.as_ref())
}
pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {

View file

@ -58,7 +58,7 @@ where
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let stream = BodyStream::from_request(req).await?;
let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
let headers = req.headers();
let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?;
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
@ -179,7 +179,6 @@ composite_rejection! {
pub enum MultipartRejection {
BodyAlreadyExtracted,
InvalidBoundary,
HeadersAlreadyExtracted,
}
}

View file

@ -124,7 +124,6 @@ composite_rejection! {
InvalidFormContentType,
FailedToDeserializeQueryString,
BytesRejection,
HeadersAlreadyExtracted,
}
}
@ -139,7 +138,6 @@ composite_rejection! {
InvalidJsonBody,
MissingJsonContentType,
BytesRejection,
HeadersAlreadyExtracted,
}
}
@ -195,8 +193,6 @@ pub enum ContentLengthLimitRejection<T> {
#[allow(missing_docs)]
LengthRequired(LengthRequired),
#[allow(missing_docs)]
HeadersAlreadyExtracted(HeadersAlreadyExtracted),
#[allow(missing_docs)]
Inner(T),
}
@ -208,7 +204,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.into_response(),
Self::LengthRequired(inner) => inner.into_response(),
Self::HeadersAlreadyExtracted(inner) => inner.into_response(),
Self::Inner(inner) => inner.into_response(),
}
}
@ -222,7 +217,6 @@ where
match self {
Self::PayloadTooLarge(inner) => inner.fmt(f),
Self::LengthRequired(inner) => inner.fmt(f),
Self::HeadersAlreadyExtracted(inner) => inner.fmt(f),
Self::Inner(inner) => inner.fmt(f),
}
}
@ -236,7 +230,6 @@ where
match self {
Self::PayloadTooLarge(inner) => Some(inner),
Self::LengthRequired(inner) => Some(inner),
Self::HeadersAlreadyExtracted(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}

View file

@ -44,16 +44,7 @@ where
type Rejection = TypedHeaderRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let headers = if let Some(headers) = req.headers() {
headers
} else {
return Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Missing,
});
};
match headers.typed_try_get::<T>() {
match req.headers().typed_try_get::<T>() {
Ok(Some(value)) => Ok(Self(value)),
Ok(None) => Err(TypedHeaderRejection {
name: T::name(),

View file

@ -249,27 +249,24 @@ where
return Err(MethodNotGet.into());
}
if !header_contains(req, header::CONNECTION, "upgrade")? {
if !header_contains(req, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_eq(req, header::UPGRADE, "websocket")? {
if !header_eq(req, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13")? {
if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}
let sec_websocket_key = if let Some(key) = req
.headers_mut()
.ok_or_else(HeadersAlreadyExtracted::default)?
.remove(header::SEC_WEBSOCKET_KEY)
{
key
} else {
return Err(WebSocketKeyHeaderMissing.into());
};
let sec_websocket_key =
if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return Err(WebSocketKeyHeaderMissing.into());
};
let on_upgrade = req
.extensions_mut()
@ -277,11 +274,7 @@ where
.remove::<OnUpgrade>()
.unwrap();
let sec_websocket_protocol = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(header::SEC_WEBSOCKET_PROTOCOL)
.cloned();
let sec_websocket_protocol = req.headers().get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
@ -293,41 +286,25 @@ where
}
}
fn header_eq<B>(
req: &RequestParts<B>,
key: HeaderName,
value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
if let Some(header) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(&key)
{
Ok(header.as_bytes().eq_ignore_ascii_case(value.as_bytes()))
fn header_eq<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool {
if let Some(header) = req.headers().get(&key) {
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
} else {
Ok(false)
false
}
}
fn header_contains<B>(
req: &RequestParts<B>,
key: HeaderName,
value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
let header = if let Some(header) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(&key)
{
fn header_contains<B>(req: &RequestParts<B>, key: HeaderName, value: &'static str) -> bool {
let header = if let Some(header) = req.headers().get(&key) {
header
} else {
return Ok(false);
return false;
};
if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
Ok(header.to_ascii_lowercase().contains(value))
header.to_ascii_lowercase().contains(value)
} else {
Ok(false)
false
}
}
@ -585,7 +562,6 @@ pub mod rejection {
InvalidUpgradeHeader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}

View file

@ -96,7 +96,7 @@ where
type Rejection = JsonRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
if json_content_type(req)? {
if json_content_type(req) {
let bytes = Bytes::from_request(req).await?;
let value = serde_json::from_slice(&bytes).map_err(InvalidJsonBody::from_err)?;
@ -108,33 +108,29 @@ where
}
}
fn json_content_type<B>(req: &RequestParts<B>) -> Result<bool, HeadersAlreadyExtracted> {
let content_type = if let Some(content_type) = req
.headers()
.ok_or_else(HeadersAlreadyExtracted::default)?
.get(header::CONTENT_TYPE)
{
fn json_content_type<B>(req: &RequestParts<B>) -> bool {
let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
content_type
} else {
return Ok(false);
return false;
};
let content_type = if let Ok(content_type) = content_type.to_str() {
content_type
} else {
return Ok(false);
return false;
};
let mime = if let Ok(mime) = content_type.parse::<mime::Mime>() {
mime
} else {
return Ok(false);
return false;
};
let is_json_content_type = mime.type_() == "application"
&& (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json"));
Ok(is_json_content_type)
is_json_content_type
}
impl<T> Deref for Json<T> {

View file

@ -73,9 +73,6 @@ where
JsonRejection::MissingJsonContentType(err) => {
(StatusCode::BAD_REQUEST, err.to_string().into())
}
JsonRejection::HeadersAlreadyExtracted(err) => {
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string().into())
}
err => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unknown internal error: {}", err).into(),