Add Host extractor (#827)

This commit is contained in:
Trent 2022-03-07 03:29:10 +11:00 committed by GitHub
parent b05a5c6dfe
commit 843437b501
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 132 additions and 0 deletions

View file

@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an - **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
extension on the response ([#797]) extension on the response ([#797])
- **added:** Implement `tower::Layer` for `Extension` ([#801]) - **added:** Implement `tower::Layer` for `Extension` ([#801])
- **added:** `extract::Host` for extracting the hostname of a request ([#827])
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819]) - **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>` - **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
as field values. as field values.

111
axum/src/extract/host.rs Normal file
View file

@ -0,0 +1,111 @@
use super::{
rejection::{FailedToResolveHost, HostRejection},
FromRequest, RequestParts,
};
use async_trait::async_trait;
const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host";
/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
#[derive(Debug, Clone)]
pub struct Host(pub String);
#[async_trait]
impl<B> FromRequest<B> for Host
where
B: Send,
{
type Rejection = HostRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
// todo: extract host from http::header::FORWARDED
if let Some(host) = req
.headers()
.get(X_FORWARDED_HOST_HEADER_KEY)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}
if let Some(host) = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(Host(host.to_owned()));
}
if let Some(host) = req.uri().host() {
return Ok(Host(host.to_owned()));
}
Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{routing::get, test_helpers::TestClient, Router};
fn test_client() -> TestClient {
async fn host_as_body(Host(host): Host) -> String {
host
}
TestClient::new(Router::new().route("/", get(host_as_body)))
}
#[tokio::test]
async fn host_header() {
let original_host = "some-domain:123";
let host = test_client()
.get("/")
.header(http::header::HOST, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}
#[tokio::test]
async fn x_forwarded_host_header() {
let original_host = "some-domain:456";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, original_host)
.send()
.await
.text()
.await;
assert_eq!(host, original_host);
}
#[tokio::test]
async fn x_forwarded_host_precedence_over_host_header() {
let x_forwarded_host_header = "some-domain:456";
let host_header = "some-domain:123";
let host = test_client()
.get("/")
.header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
.header(http::header::HOST, host_header)
.send()
.await
.text()
.await;
assert_eq!(host, x_forwarded_host_header);
}
#[tokio::test]
async fn uri_host() {
let host = test_client().get("/").send().await.text().await;
assert!(host.contains("127.0.0.1"));
}
}

View file

@ -12,6 +12,7 @@ pub mod rejection;
pub mod ws; pub mod ws;
mod content_length_limit; mod content_length_limit;
mod host;
mod raw_query; mod raw_query;
mod request_parts; mod request_parts;
@ -23,6 +24,7 @@ pub use self::{
connect_info::ConnectInfo, connect_info::ConnectInfo,
content_length_limit::ContentLengthLimit, content_length_limit::ContentLengthLimit,
extractor_middleware::extractor_middleware, extractor_middleware::extractor_middleware,
host::Host,
path::Path, path::Path,
raw_query::RawQuery, raw_query::RawQuery,
request_parts::{BodyStream, RawBody}, request_parts::{BodyStream, RawBody},

View file

@ -65,6 +65,14 @@ define_rejection! {
pub struct InvalidFormContentType; pub struct InvalidFormContentType;
} }
define_rejection! {
#[status = BAD_REQUEST]
#[body = "No host found in request"]
/// Rejection type used if the [`Host`](super::Host) extractor is unable to
/// resolve a host.
pub struct FailedToResolveHost;
}
/// Rejection type for extractors that deserialize query strings if the input /// Rejection type for extractors that deserialize query strings if the input
/// couldn't be deserialized into the target type. /// couldn't be deserialized into the target type.
#[derive(Debug)] #[derive(Debug)]
@ -160,6 +168,16 @@ composite_rejection! {
} }
} }
composite_rejection! {
/// Rejection used for [`Host`](super::Host).
///
/// Contains one variant for each way the [`Host`](super::Host) extractor
/// can fail.
pub enum HostRejection {
FailedToResolveHost,
}
}
#[cfg(feature = "matched-path")] #[cfg(feature = "matched-path")]
define_rejection! { define_rejection! {
#[status = INTERNAL_SERVER_ERROR] #[status = INTERNAL_SERVER_ERROR]