diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index a35948da..52072240 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -33,6 +33,7 @@ json-lines = [ ] multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] +scheme = [] query = ["dep:serde_html_form"] tracing = ["axum-core/tracing", "axum/tracing"] typed-header = ["dep:headers"] diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index a0e710d1..7d2a5b24 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -21,6 +21,9 @@ mod query; #[cfg(feature = "multipart")] pub mod multipart; +#[cfg(feature = "scheme")] +mod scheme; + pub use self::{ cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection, }; @@ -43,6 +46,10 @@ pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejecti #[cfg(feature = "multipart")] pub use self::multipart::Multipart; +#[cfg(feature = "scheme")] +#[doc(no_inline)] +pub use self::scheme::{Scheme, SchemeMissing}; + #[cfg(feature = "json-deserializer")] pub use self::json_deserializer::{ JsonDataError, JsonDeserializer, JsonDeserializerRejection, JsonSyntaxError, diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs new file mode 100644 index 00000000..891d5c0b --- /dev/null +++ b/axum-extra/src/extract/scheme.rs @@ -0,0 +1,152 @@ +//! Extractor that parses the scheme of a request. +//! See [`Scheme`] for more details. + +use axum::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use http::{ + header::{HeaderMap, FORWARDED}, + request::Parts, +}; +const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto"; + +/// Extractor that resolves the scheme / protocol of a request. +/// +/// The scheme is resolved through the following, in order: +/// - `Forwarded` header +/// - `X-Forwarded-Proto` header +/// - Request URI (If the request is an HTTP/2 request! e.g. use `--http2(-prior-knowledge)` with cURL) +/// +/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make +/// sure to validate them to avoid security issues. +#[derive(Debug, Clone)] +pub struct Scheme(pub String); + +/// Rejection type used if the [`Scheme`] extractor is unable to +/// resolve a scheme. +#[derive(Debug)] +pub struct SchemeMissing; + +impl IntoResponse for SchemeMissing { + fn into_response(self) -> Response { + (http::StatusCode::BAD_REQUEST, "No scheme found in request").into_response() + } +} + +impl FromRequestParts for Scheme +where + S: Send + Sync, +{ + type Rejection = SchemeMissing; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + // Within Forwarded header + if let Some(scheme) = parse_forwarded(&parts.headers) { + return Ok(Scheme(scheme.to_owned())); + } + + // X-Forwarded-Proto + if let Some(scheme) = parts + .headers + .get(X_FORWARDED_PROTO_HEADER_KEY) + .and_then(|scheme| scheme.to_str().ok()) + { + return Ok(Scheme(scheme.to_owned())); + } + + // From parts of an HTTP/2 request + if let Some(scheme) = parts.uri.scheme_str() { + return Ok(Scheme(scheme.to_owned())); + } + + Err(SchemeMissing) + } +} + +fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { + // if there are multiple `Forwarded` `HeaderMap::get` will return the first one + let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?; + + // get the first set of values + let first_value = forwarded_values.split(',').next()?; + + // find the value of the `proto` field + first_value.split(';').find_map(|pair| { + let (key, value) = pair.split_once('=')?; + key.trim() + .eq_ignore_ascii_case("proto") + .then(|| value.trim().trim_matches('"')) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::TestClient; + use axum::{routing::get, Router}; + use http::header::HeaderName; + + fn test_client() -> TestClient { + async fn scheme_as_body(Scheme(scheme): Scheme) -> String { + scheme + } + + TestClient::new(Router::new().route("/", get(scheme_as_body))) + } + + #[crate::test] + async fn forwarded_scheme_parsing() { + // the basic case + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "http"); + + // is case insensitive + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "https"); + + // multiple values in one header + let headers = header_map(&[(FORWARDED, "proto=ftp, proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + + // multiple header values + let headers = header_map(&[(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + } + + #[crate::test] + async fn x_forwarded_scheme_header() { + let original_scheme = "https"; + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, original_scheme) + .await + .text() + .await; + assert_eq!(scheme, original_scheme); + } + + #[crate::test] + async fn precedence_forwarded_over_x_forwarded() { + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, "https") + .header(FORWARDED, "proto=ftp") + .await + .text() + .await; + assert_eq!(scheme, "ftp"); + } + + fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap { + let mut headers = HeaderMap::new(); + for (key, value) in values { + headers.append(key, value.parse().unwrap()); + } + headers + } +}