diff --git a/examples/http-proxy/Cargo.toml b/examples/http-proxy/Cargo.toml index c44f3094..aa607002 100644 --- a/examples/http-proxy/Cargo.toml +++ b/examples/http-proxy/Cargo.toml @@ -6,7 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -hyper = { version = "1.0.0", features = ["full"] } +hyper = { version = "1", features = ["full"] } +hyper-util = "0.1.1" tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["make"] } tracing = "0.1" diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs index 04ef3e51..b60ed03d 100644 --- a/examples/http-proxy/src/main.rs +++ b/examples/http-proxy/src/main.rs @@ -12,96 +12,114 @@ //! //! Example is based on -// TODO -fn main() { - eprint!("this example has not yet been updated to hyper 1.0"); +use axum::{ + body::Body, + extract::Request, + http::{Method, StatusCode}, + response::{IntoResponse, Response}, + routing::get, + Router, +}; + +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::upgrade::Upgraded; +use std::net::SocketAddr; +use tokio::net::{TcpListener, TcpStream}; +use tower::Service; +use tower::ServiceExt; + +use hyper_util::rt::TokioIo; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let router_svc = Router::new().route("/", get(|| async { "Hello, World!" })); + + let tower_service = tower::service_fn(move |req: Request<_>| { + let router_svc = router_svc.clone(); + let req = req.map(Body::new); + async move { + if req.method() == Method::CONNECT { + proxy(req).await + } else { + router_svc.oneshot(req).await.map_err(|err| match err {}) + } + } + }); + + let hyper_service = hyper::service::service_fn(move |request: Request| { + tower_service.clone().call(request) + }); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + + let listener = TcpListener::bind(addr).await.unwrap(); + loop { + let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + let hyper_service = hyper_service.clone(); + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .serve_connection(io, hyper_service) + .with_upgrades() + .await + { + println!("Failed to serve connection: {:?}", err); + } + }); + } } -// use axum::{ -// body::Body, -// extract::Request, -// http::{Method, StatusCode}, -// response::{IntoResponse, Response}, -// routing::get, -// Router, -// }; -// use hyper::upgrade::Upgraded; -// use std::net::SocketAddr; -// use tokio::net::TcpStream; -// use tower::{make::Shared, ServiceExt}; -// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +async fn proxy(req: Request) -> Result { + tracing::trace!(?req); -// #[tokio::main] -// async fn main() { -// tracing_subscriber::registry() -// .with( -// tracing_subscriber::EnvFilter::try_from_default_env() -// .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()), -// ) -// .with(tracing_subscriber::fmt::layer()) -// .init(); + if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) { + tokio::task::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + if let Err(e) = tunnel(upgraded, host_addr).await { + tracing::warn!("server io error: {}", e); + }; + } + Err(e) => tracing::warn!("upgrade error: {}", e), + } + }); -// let router_svc = Router::new().route("/", get(|| async { "Hello, World!" })); + Ok(Response::new(Body::empty())) + } else { + tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri()); + Ok(( + StatusCode::BAD_REQUEST, + "CONNECT must be to a socket address", + ) + .into_response()) + } +} -// let service = tower::service_fn(move |req: Request<_>| { -// let router_svc = router_svc.clone(); -// let req = req.map(Body::new); -// async move { -// if req.method() == Method::CONNECT { -// proxy(req).await -// } else { -// router_svc.oneshot(req).await.map_err(|err| match err {}) -// } -// } -// }); +async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { + let mut server = TcpStream::connect(addr).await?; + let mut upgraded = TokioIo::new(upgraded); -// let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); -// tracing::debug!("listening on {}", addr); -// hyper::Server::bind(&addr) -// .http1_preserve_header_case(true) -// .http1_title_case_headers(true) -// .serve(Shared::new(service)) -// .await -// .unwrap(); -// } + let (from_client, from_server) = + tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; -// async fn proxy(req: Request) -> Result { -// tracing::trace!(?req); + tracing::debug!( + "client wrote {} bytes and received {} bytes", + from_client, + from_server + ); -// if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) { -// tokio::task::spawn(async move { -// match hyper::upgrade::on(req).await { -// Ok(upgraded) => { -// if let Err(e) = tunnel(upgraded, host_addr).await { -// tracing::warn!("server io error: {}", e); -// }; -// } -// Err(e) => tracing::warn!("upgrade error: {}", e), -// } -// }); - -// Ok(Response::new(Body::empty())) -// } else { -// tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri()); -// Ok(( -// StatusCode::BAD_REQUEST, -// "CONNECT must be to a socket address", -// ) -// .into_response()) -// } -// } - -// async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> { -// let mut server = TcpStream::connect(addr).await?; - -// let (from_client, from_server) = -// tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; - -// tracing::debug!( -// "client wrote {} bytes and received {} bytes", -// from_client, -// from_server -// ); - -// Ok(()) -// } + Ok(()) +}