refactor: add some informative header (X-Forwarded-For, X-Forwarded-Host, X-Forwarded-Proto, Forwarded)

This commit is contained in:
Afifurrohman 2024-05-10 15:08:52 +00:00
parent e01086b179
commit 40971a46c5

View file

@ -1,5 +1,5 @@
//! Reverse proxy listening in "localhost:4000" will proxy all `GET` requests to "localhost:3000" except for path /https is example.com //! Reverse proxy listening in "localhost:4000" will proxy all `GET` requests to "localhost:3000"
//! endpoint. //! except for path /https is example.com endpoint.
//! //!
//! On unix like OS: make sure `ca-certificates` is installed. //! On unix like OS: make sure `ca-certificates` is installed.
//! //!
@ -9,6 +9,8 @@
//! cargo run -p example-reverse-proxy //! cargo run -p example-reverse-proxy
//! ``` //! ```
use axum::extract::ConnectInfo;
use axum::http::header::FORWARDED;
use axum::http::{header::HOST, StatusCode}; use axum::http::{header::HOST, StatusCode};
use axum::{ use axum::{
body::Body, body::Body,
@ -19,6 +21,7 @@ use axum::{
}; };
use hyper_tls::HttpsConnector; use hyper_tls::HttpsConnector;
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use std::net::SocketAddr;
type Client = hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, Body>; type Client = hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, Body>;
@ -38,16 +41,31 @@ async fn main() {
.await .await
.unwrap(); .unwrap();
println!("listening on {}", listener.local_addr().unwrap()); println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap(); axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
} }
async fn handler(State(client): State<Client>, mut req: Request) -> Result<Response, StatusCode> { async fn handler(
let path = req.uri().path(); State(client): State<Client>,
let path_query = req ConnectInfo(addr): ConnectInfo<SocketAddr>,
.uri() mut req: Request,
.path_and_query() ) -> Result<Response, StatusCode> {
.map(|v| v.as_str()) let uri = req.uri();
.unwrap_or(path);
let ip = addr.ip().to_string();
let host = uri
.authority()
.map(|a| a.as_str())
.unwrap_or("127.0.0.1:4000")
.to_string();
let proto = uri.scheme_str().unwrap_or("http").to_string();
let path = uri.path();
let path_query = uri.path_and_query().map(|v| v.as_str()).unwrap_or(path);
let mut uri = format!("http://127.0.0.1:3000{}", path_query); let mut uri = format!("http://127.0.0.1:3000{}", path_query);
if path == "/https" { if path == "/https" {
@ -56,13 +74,32 @@ async fn handler(State(client): State<Client>, mut req: Request) -> Result<Respo
*req.uri_mut() = Uri::try_from(uri).unwrap(); *req.uri_mut() = Uri::try_from(uri).unwrap();
//? Remove incorrect header host, hyper will add automatically for you. // Remove incorrect header host, hyper will add automatically for you.
req.headers_mut().remove(HOST).unwrap(); req.headers_mut().remove(HOST);
// Add some informative header (de-facto)
req.headers_mut()
.insert("X-Forwarded-For", ip.parse().unwrap());
req.headers_mut()
.insert("X-Forwarded-Host", host.parse().unwrap());
req.headers_mut()
.insert("X-Forwarded-Proto", proto.parse().unwrap());
// a standardized
req.headers_mut().insert(
FORWARDED,
format!("for={ip};host={host};proto={proto};")
.parse()
.unwrap(),
);
Ok(client Ok(client
.request(req) .request(req)
.await .await
.map_err(|_| StatusCode::BAD_REQUEST)? .map_err(|err| {
eprintln!("{:?}", err);
StatusCode::BAD_REQUEST
})?
.into_response()) .into_response())
} }