Support Forwarded in Host extractor (#1078)

* Support `Forwarded` in `Host` extractor

* changelog

* Update axum/src/extract/host.rs

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>

* look for `host` key

Co-authored-by: Jonas Platte <jplatte+git@posteo.de>
This commit is contained in:
David Pedersen 2022-06-10 22:30:01 +02:00 committed by GitHub
parent f66893fbda
commit dbdbd0165e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 2 deletions

View file

@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- **added:** Support resolving host name via `Forwarded` header in `Host`
extractor ([#1078])
[#1078]: https://github.com/tokio-rs/axum/pull/1078
# 0.5.7 (08. June, 2022)

View file

@ -3,12 +3,14 @@ use super::{
FromRequest, RequestParts,
};
use async_trait::async_trait;
use http::header::{HeaderMap, FORWARDED};
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `Forwarded` header
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
@ -26,7 +28,9 @@ where
type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// TODO: extract host from http::header::FORWARDED
if let Some(host) = parse_forwarded(req.headers()) {
return Ok(Host(host.to_owned()));
}
if let Some(host) = req
.headers()
@ -52,10 +56,28 @@ where
}
}
#[allow(warnings)]
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
// if there are multiple `Forwarded` `HeaderMap::get` will return the first one
let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
// get the first set of values
let first_value = forwarded_values.split(',').nth(0)?;
// find the value of the `host` field
first_value.split(';').find_map(|pair| {
let (key, value) = pair.split_once('=')?;
key.trim()
.eq_ignore_ascii_case("host")
.then(|| value.trim().trim_matches('"'))
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};
use http::header::HeaderName;
fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
@ -111,4 +133,43 @@ mod tests {
let host = test_client().get("/").send().await.text().await;
assert!(host.contains("127.0.0.1"));
}
#[test]
fn forwarded_parsing() {
// the basic case
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
// is case insensitive
let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
// ipv6
let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "[2001:db8:cafe::17]:4711");
// multiple values in one header
let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
// multiple header values
let headers = header_map(&[
(FORWARDED, "host=192.0.2.60"),
(FORWARDED, "host=127.0.0.1"),
]);
let value = parse_forwarded(&headers).unwrap();
assert_eq!(value, "192.0.2.60");
}
fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
let mut headers = HeaderMap::new();
for (key, value) in values {
headers.append(key, value.parse().unwrap());
}
headers
}
}