mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-01 08:56:15 +01:00
Add Host extractor (#827)
This commit is contained in:
parent
b05a5c6dfe
commit
843437b501
4 changed files with 132 additions and 0 deletions
|
@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- **added:** `Extension<_>` can now be used in tuples for building responses, and will set an
|
||||
extension on the response ([#797])
|
||||
- **added:** Implement `tower::Layer` for `Extension` ([#801])
|
||||
- **added:** `extract::Host` for extracting the hostname of a request ([#827])
|
||||
- **changed:** `Router::merge` now accepts `Into<Router>` ([#819])
|
||||
- **breaking:** `sse::Event` now accepts types implementing `AsRef<str>` instead of `Into<String>`
|
||||
as field values.
|
||||
|
|
111
axum/src/extract/host.rs
Normal file
111
axum/src/extract/host.rs
Normal file
|
@ -0,0 +1,111 @@
|
|||
use super::{
|
||||
rejection::{FailedToResolveHost, HostRejection},
|
||||
FromRequest, RequestParts,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
|
||||
const X_FORWARDED_HOST_HEADER_KEY: &'static str = "X-Forwarded-Host";
|
||||
|
||||
/// Extractor that resolves the hostname of the request.
|
||||
///
|
||||
/// Hostname is resolved through the following, in order:
|
||||
/// - `X-Forwarded-Host` header
|
||||
/// - `Host` header
|
||||
/// - request target / URI
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Host(pub String);
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for Host
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = HostRejection;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
// todo: extract host from http::header::FORWARDED
|
||||
|
||||
if let Some(host) = req
|
||||
.headers()
|
||||
.get(X_FORWARDED_HOST_HEADER_KEY)
|
||||
.and_then(|host| host.to_str().ok())
|
||||
{
|
||||
return Ok(Host(host.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(host) = req
|
||||
.headers()
|
||||
.get(http::header::HOST)
|
||||
.and_then(|host| host.to_str().ok())
|
||||
{
|
||||
return Ok(Host(host.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(host) = req.uri().host() {
|
||||
return Ok(Host(host.to_owned()));
|
||||
}
|
||||
|
||||
Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{routing::get, test_helpers::TestClient, Router};
|
||||
|
||||
fn test_client() -> TestClient {
|
||||
async fn host_as_body(Host(host): Host) -> String {
|
||||
host
|
||||
}
|
||||
|
||||
TestClient::new(Router::new().route("/", get(host_as_body)))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn host_header() {
|
||||
let original_host = "some-domain:123";
|
||||
let host = test_client()
|
||||
.get("/")
|
||||
.header(http::header::HOST, original_host)
|
||||
.send()
|
||||
.await
|
||||
.text()
|
||||
.await;
|
||||
assert_eq!(host, original_host);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn x_forwarded_host_header() {
|
||||
let original_host = "some-domain:456";
|
||||
let host = test_client()
|
||||
.get("/")
|
||||
.header(X_FORWARDED_HOST_HEADER_KEY, original_host)
|
||||
.send()
|
||||
.await
|
||||
.text()
|
||||
.await;
|
||||
assert_eq!(host, original_host);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn x_forwarded_host_precedence_over_host_header() {
|
||||
let x_forwarded_host_header = "some-domain:456";
|
||||
let host_header = "some-domain:123";
|
||||
let host = test_client()
|
||||
.get("/")
|
||||
.header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
|
||||
.header(http::header::HOST, host_header)
|
||||
.send()
|
||||
.await
|
||||
.text()
|
||||
.await;
|
||||
assert_eq!(host, x_forwarded_host_header);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn uri_host() {
|
||||
let host = test_client().get("/").send().await.text().await;
|
||||
assert!(host.contains("127.0.0.1"));
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ pub mod rejection;
|
|||
pub mod ws;
|
||||
|
||||
mod content_length_limit;
|
||||
mod host;
|
||||
mod raw_query;
|
||||
mod request_parts;
|
||||
|
||||
|
@ -23,6 +24,7 @@ pub use self::{
|
|||
connect_info::ConnectInfo,
|
||||
content_length_limit::ContentLengthLimit,
|
||||
extractor_middleware::extractor_middleware,
|
||||
host::Host,
|
||||
path::Path,
|
||||
raw_query::RawQuery,
|
||||
request_parts::{BodyStream, RawBody},
|
||||
|
|
|
@ -65,6 +65,14 @@ define_rejection! {
|
|||
pub struct InvalidFormContentType;
|
||||
}
|
||||
|
||||
define_rejection! {
|
||||
#[status = BAD_REQUEST]
|
||||
#[body = "No host found in request"]
|
||||
/// Rejection type used if the [`Host`](super::Host) extractor is unable to
|
||||
/// resolve a host.
|
||||
pub struct FailedToResolveHost;
|
||||
}
|
||||
|
||||
/// Rejection type for extractors that deserialize query strings if the input
|
||||
/// couldn't be deserialized into the target type.
|
||||
#[derive(Debug)]
|
||||
|
@ -160,6 +168,16 @@ composite_rejection! {
|
|||
}
|
||||
}
|
||||
|
||||
composite_rejection! {
|
||||
/// Rejection used for [`Host`](super::Host).
|
||||
///
|
||||
/// Contains one variant for each way the [`Host`](super::Host) extractor
|
||||
/// can fail.
|
||||
pub enum HostRejection {
|
||||
FailedToResolveHost,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "matched-path")]
|
||||
define_rejection! {
|
||||
#[status = INTERNAL_SERVER_ERROR]
|
||||
|
|
Loading…
Reference in a new issue