diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index e9d0fc67..a7ccf4ff 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** `Extension<_>` can now be used in tuples for building responses, and will set an extension on the response ([#797]) - **added:** Implement `tower::Layer` for `Extension` ([#801]) +- **added:** `extract::Host` for extracting the hostname of a request ([#827]) - **changed:** `Router::merge` now accepts `Into` ([#819]) - **breaking:** `sse::Event` now accepts types implementing `AsRef` instead of `Into` as field values. diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs new file mode 100644 index 00000000..8e2124fb --- /dev/null +++ b/axum/src/extract/host.rs @@ -0,0 +1,111 @@ +use super::{ + rejection::{FailedToResolveHost, HostRejection}, + FromRequest, RequestParts, +}; +use async_trait::async_trait; + +const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host"; + +/// Extractor that resolves the hostname of the request. +/// +/// Hostname is resolved through the following, in order: +/// - `X-Forwarded-Host` header +/// - `Host` header +/// - request target / URI +#[derive(Debug, Clone)] +pub struct Host(pub String); + +#[async_trait] +impl FromRequest for Host +where + B: Send, +{ + type Rejection = HostRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + // todo: extract host from http::header::FORWARDED + + if let Some(host) = req + .headers() + .get(X_FORWARDED_HOST_HEADER_KEY) + .and_then(|host| host.to_str().ok()) + { + return Ok(Host(host.to_owned())); + } + + if let Some(host) = req + .headers() + .get(http::header::HOST) + .and_then(|host| host.to_str().ok()) + { + return Ok(Host(host.to_owned())); + } + + if let Some(host) = req.uri().host() { + return Ok(Host(host.to_owned())); + } + + Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{routing::get, test_helpers::TestClient, Router}; + + fn test_client() -> TestClient { + async fn host_as_body(Host(host): Host) -> String { + host + } + + TestClient::new(Router::new().route("/", get(host_as_body))) + } + + #[tokio::test] + async fn host_header() { + let original_host = "some-domain:123"; + let host = test_client() + .get("/") + .header(http::header::HOST, original_host) + .send() + .await + .text() + .await; + assert_eq!(host, original_host); + } + + #[tokio::test] + async fn x_forwarded_host_header() { + let original_host = "some-domain:456"; + let host = test_client() + .get("/") + .header(X_FORWARDED_HOST_HEADER_KEY, original_host) + .send() + .await + .text() + .await; + assert_eq!(host, original_host); + } + + #[tokio::test] + async fn x_forwarded_host_precedence_over_host_header() { + let x_forwarded_host_header = "some-domain:456"; + let host_header = "some-domain:123"; + let host = test_client() + .get("/") + .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header) + .header(http::header::HOST, host_header) + .send() + .await + .text() + .await; + assert_eq!(host, x_forwarded_host_header); + } + + #[tokio::test] + async fn uri_host() { + let host = test_client().get("/").send().await.text().await; + assert!(host.contains("127.0.0.1")); + } +} diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 00a11387..db7f974a 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -12,6 +12,7 @@ pub mod rejection; pub mod ws; mod content_length_limit; +mod host; mod raw_query; mod request_parts; @@ -23,6 +24,7 @@ pub use self::{ connect_info::ConnectInfo, content_length_limit::ContentLengthLimit, extractor_middleware::extractor_middleware, + host::Host, path::Path, raw_query::RawQuery, request_parts::{BodyStream, RawBody}, diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 780327c8..737a70bd 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -65,6 +65,14 @@ define_rejection! { pub struct InvalidFormContentType; } +define_rejection! { + #[status = BAD_REQUEST] + #[body = "No host found in request"] + /// Rejection type used if the [`Host`](super::Host) extractor is unable to + /// resolve a host. + pub struct FailedToResolveHost; +} + /// Rejection type for extractors that deserialize query strings if the input /// couldn't be deserialized into the target type. #[derive(Debug)] @@ -160,6 +168,16 @@ composite_rejection! { } } +composite_rejection! { + /// Rejection used for [`Host`](super::Host). + /// + /// Contains one variant for each way the [`Host`](super::Host) extractor + /// can fail. + pub enum HostRejection { + FailedToResolveHost, + } +} + #[cfg(feature = "matched-path")] define_rejection! { #[status = INTERNAL_SERVER_ERROR]