mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-04 02:01:23 +01:00
Add Host extractor (#827)
This commit is contained in:
parent
b05a5c6dfe
commit
843437b501
4 changed files with 132 additions and 0 deletions
|
@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
|
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
|
||||||
extension on the response ([#797])
|
extension on the response ([#797])
|
||||||
- **added:** Implement `tower::Layer` for `Extension` ([#801])
|
- **added:** Implement `tower::Layer` for `Extension` ([#801])
|
||||||
|
- **added:** `extract::Host` for extracting the hostname of a request ([#827])
|
||||||
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
|
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
|
||||||
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
|
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
|
||||||
as field values.
|
as field values.
|
||||||
|
|
111
axum/src/extract/host.rs
Normal file
111
axum/src/extract/host.rs
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
use super::{
|
||||||
|
rejection::{FailedToResolveHost, HostRejection},
|
||||||
|
FromRequest, RequestParts,
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host";
|
||||||
|
|
||||||
|
/// Extractor that resolves the hostname of the request.
|
||||||
|
///
|
||||||
|
/// Hostname is resolved through the following, in order:
|
||||||
|
/// - `X-Forwarded-Host` header
|
||||||
|
/// - `Host` header
|
||||||
|
/// - request target / URI
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Host(pub String);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<B> FromRequest<B> for Host
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Rejection = HostRejection;
|
||||||
|
|
||||||
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
// todo: extract host from http::header::FORWARDED
|
||||||
|
|
||||||
|
if let Some(host) = req
|
||||||
|
.headers()
|
||||||
|
.get(X_FORWARDED_HOST_HEADER_KEY)
|
||||||
|
.and_then(|host| host.to_str().ok())
|
||||||
|
{
|
||||||
|
return Ok(Host(host.to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(host) = req
|
||||||
|
.headers()
|
||||||
|
.get(http::header::HOST)
|
||||||
|
.and_then(|host| host.to_str().ok())
|
||||||
|
{
|
||||||
|
return Ok(Host(host.to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(host) = req.uri().host() {
|
||||||
|
return Ok(Host(host.to_owned()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{routing::get, test_helpers::TestClient, Router};
|
||||||
|
|
||||||
|
fn test_client() -> TestClient {
|
||||||
|
async fn host_as_body(Host(host): Host) -> String {
|
||||||
|
host
|
||||||
|
}
|
||||||
|
|
||||||
|
TestClient::new(Router::new().route("/", get(host_as_body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn host_header() {
|
||||||
|
let original_host = "some-domain:123";
|
||||||
|
let host = test_client()
|
||||||
|
.get("/")
|
||||||
|
.header(http::header::HOST, original_host)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.text()
|
||||||
|
.await;
|
||||||
|
assert_eq!(host, original_host);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn x_forwarded_host_header() {
|
||||||
|
let original_host = "some-domain:456";
|
||||||
|
let host = test_client()
|
||||||
|
.get("/")
|
||||||
|
.header(X_FORWARDED_HOST_HEADER_KEY, original_host)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.text()
|
||||||
|
.await;
|
||||||
|
assert_eq!(host, original_host);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn x_forwarded_host_precedence_over_host_header() {
|
||||||
|
let x_forwarded_host_header = "some-domain:456";
|
||||||
|
let host_header = "some-domain:123";
|
||||||
|
let host = test_client()
|
||||||
|
.get("/")
|
||||||
|
.header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
|
||||||
|
.header(http::header::HOST, host_header)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.text()
|
||||||
|
.await;
|
||||||
|
assert_eq!(host, x_forwarded_host_header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn uri_host() {
|
||||||
|
let host = test_client().get("/").send().await.text().await;
|
||||||
|
assert!(host.contains("127.0.0.1"));
|
||||||
|
}
|
||||||
|
}
|
|
@ -12,6 +12,7 @@ pub mod rejection;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|
||||||
mod content_length_limit;
|
mod content_length_limit;
|
||||||
|
mod host;
|
||||||
mod raw_query;
|
mod raw_query;
|
||||||
mod request_parts;
|
mod request_parts;
|
||||||
|
|
||||||
|
@ -23,6 +24,7 @@ pub use self::{
|
||||||
connect_info::ConnectInfo,
|
connect_info::ConnectInfo,
|
||||||
content_length_limit::ContentLengthLimit,
|
content_length_limit::ContentLengthLimit,
|
||||||
extractor_middleware::extractor_middleware,
|
extractor_middleware::extractor_middleware,
|
||||||
|
host::Host,
|
||||||
path::Path,
|
path::Path,
|
||||||
raw_query::RawQuery,
|
raw_query::RawQuery,
|
||||||
request_parts::{BodyStream, RawBody},
|
request_parts::{BodyStream, RawBody},
|
||||||
|
|
|
@ -65,6 +65,14 @@ define_rejection! {
|
||||||
pub struct InvalidFormContentType;
|
pub struct InvalidFormContentType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
define_rejection! {
|
||||||
|
#[status = BAD_REQUEST]
|
||||||
|
#[body = "No host found in request"]
|
||||||
|
/// Rejection type used if the [`Host`](super::Host) extractor is unable to
|
||||||
|
/// resolve a host.
|
||||||
|
pub struct FailedToResolveHost;
|
||||||
|
}
|
||||||
|
|
||||||
/// Rejection type for extractors that deserialize query strings if the input
|
/// Rejection type for extractors that deserialize query strings if the input
|
||||||
/// couldn't be deserialized into the target type.
|
/// couldn't be deserialized into the target type.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -160,6 +168,16 @@ composite_rejection! {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
composite_rejection! {
|
||||||
|
/// Rejection used for [`Host`](super::Host).
|
||||||
|
///
|
||||||
|
/// Contains one variant for each way the [`Host`](super::Host) extractor
|
||||||
|
/// can fail.
|
||||||
|
pub enum HostRejection {
|
||||||
|
FailedToResolveHost,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "matched-path")]
|
#[cfg(feature = "matched-path")]
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = INTERNAL_SERVER_ERROR]
|
#[status = INTERNAL_SERVER_ERROR]
|
||||||
|
|
Loading…
Reference in a new issue