From 40971a46c54db68460036d5d4091d7bb177bc872 Mon Sep 17 00:00:00 2001 From: Afifurrohman Date: Fri, 10 May 2024 15:08:52 +0000 Subject: [PATCH] refactor: add some informative header (`X-Forwarded-For`, `X-Forwarded-Host`, `X-Forwarded-Proto`, `Forwarded`) --- examples/reverse-proxy/src/main.rs | 63 ++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index 575f7c84..cc176680 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -1,5 +1,5 @@ -//! Reverse proxy listening in "localhost:4000" will proxy all `GET` requests to "localhost:3000" except for path /https is example.com -//! endpoint. +//! Reverse proxy listening in "localhost:4000" will proxy all `GET` requests to "localhost:3000" +//! except for path /https is example.com endpoint. //! //! On unix like OS: make sure `ca-certificates` is installed. //! @@ -9,6 +9,8 @@ //! cargo run -p example-reverse-proxy //! ``` +use axum::extract::ConnectInfo; +use axum::http::header::FORWARDED; use axum::http::{header::HOST, StatusCode}; use axum::{ body::Body, @@ -19,6 +21,7 @@ use axum::{ }; use hyper_tls::HttpsConnector; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; +use std::net::SocketAddr; type Client = hyper_util::client::legacy::Client, Body>; @@ -38,16 +41,31 @@ async fn main() { .await .unwrap(); println!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await.unwrap(); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .unwrap(); } -async fn handler(State(client): State, mut req: Request) -> Result { - let path = req.uri().path(); - let path_query = req - .uri() - .path_and_query() - .map(|v| v.as_str()) - .unwrap_or(path); +async fn handler( + State(client): State, + ConnectInfo(addr): ConnectInfo, + mut req: Request, +) -> Result { + let uri = req.uri(); + + let ip = addr.ip().to_string(); + let host = uri + .authority() + .map(|a| a.as_str()) + .unwrap_or("127.0.0.1:4000") + .to_string(); + let proto = uri.scheme_str().unwrap_or("http").to_string(); + + let path = uri.path(); + let path_query = uri.path_and_query().map(|v| v.as_str()).unwrap_or(path); let mut uri = format!("http://127.0.0.1:3000{}", path_query); if path == "/https" { @@ -56,13 +74,32 @@ async fn handler(State(client): State, mut req: Request) -> Result