diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 10d8a930..b0caab21 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,9 +7,11 @@ and this project adheres to [Semantic Versioning]. # Unreleased +- **fixed:** `Host` extractor includes port number when parsing authority ([#2242]) - **added:** Add `RouterExt::typed_connect` ([#2961]) - **added:** Add `json!` for easy construction of JSON responses ([#2962]) +[#2242]: https://github.com/tokio-rs/axum/pull/2242 [#2961]: https://github.com/tokio-rs/axum/pull/2961 [#2962]: https://github.com/tokio-rs/axum/pull/2962 diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index 477bc4fa..a6828d30 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -3,17 +3,21 @@ use axum::extract::FromRequestParts; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, + uri::Authority, }; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; -/// Extractor that resolves the hostname of the request. +/// Extractor that resolves the host of the request. /// -/// Hostname is resolved through the following, in order: +/// Host is resolved through the following, in order: /// - `Forwarded` header /// - `X-Forwarded-Host` header /// - `Host` header -/// - request target / URI +/// - Authority of the request URI +/// +/// See for the definition of +/// host. /// /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. @@ -47,8 +51,8 @@ where return Ok(Host(host.to_owned())); } - if let Some(host) = parts.uri.host() { - return Ok(Host(host.to_owned())); + if let Some(authority) = parts.uri.authority() { + return Ok(Host(parse_authority(authority).to_owned())); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) @@ -72,12 +76,19 @@ fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { }) } +fn parse_authority(auth: &Authority) -> &str { + auth.as_str() + .rsplit('@') + .next() + .expect("split always has at least 1 item") +} + #[cfg(test)] mod tests { use super::*; use crate::test_helpers::TestClient; use axum::{routing::get, Router}; - use http::header::HeaderName; + use http::{header::HeaderName, Request}; fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { @@ -127,8 +138,26 @@ mod tests { #[crate::test] async fn uri_host() { - let host = test_client().get("/").await.text().await; - assert!(host.contains("127.0.0.1")); + let client = test_client(); + let port = client.server_port(); + let host = client.get("/").await.text().await; + assert_eq!(host, format!("127.0.0.1:{port}")); + } + + #[crate::test] + async fn ip4_uri_host() { + let mut parts = Request::new(()).into_parts().0; + parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "127.0.0.1:1234"); + } + + #[crate::test] + async fn ip6_uri_host() { + let mut parts = Request::new(()).into_parts().0; + parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "[::1]:456"); } #[test] diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 058a0245..2dfa95a0 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -83,6 +83,11 @@ impl TestClient { builder: self.client.patch(format!("http://{}{}", self.addr, url)), } } + + #[allow(dead_code)] + pub(crate) fn server_port(&self) -> u16 { + self.addr.port() + } } pub(crate) struct RequestBuilder {