// axum/examples/low-level-rustls/src/main.rs

//! Run with
//!
//! ```not_rust
//! cargo run -p example-low-level-rustls
//! ```
use axum::{extract::Request, routing::get, Router};
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::{
fs::File,
io::BufReader,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::net::TcpListener;
use tokio_rustls::{
rustls::{Certificate, PrivateKey, ServerConfig},
TlsAcceptor,
};
use tower_service::Service;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let rustls_config = rustls_server_config(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
);
let tls_acceptor = TlsAcceptor::from(rustls_config);
let bind = "[::1]:3000";
let tcp_listener = TcpListener::bind(bind).await.unwrap();
info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000");
let app = Router::new().route("/", get(handler));
pin_mut!(tcp_listener);
loop {
let tower_service = app.clone();
let tls_acceptor = tls_acceptor.clone();
// Wait for new tcp connection
let (cnx, addr) = tcp_listener.accept().await.unwrap();
tokio::spawn(async move {
// Wait for tls handshake to happen
let Ok(stream) = tls_acceptor.accept(cnx).await else {
error!("error during tls handshake connection from {}", addr);
return;
};
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
// `TokioIo` converts between them.
let stream = TokioIo::new(stream);
// Hyper also has its own `Service` trait and doesn't use tower. We can use
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
// `tower::Service::call`.
let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
// tower's `Service` requires `&mut self`.
//
// We don't need to call `poll_ready` since `Router` is always ready.
tower_service.clone().call(request)
});
let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;
if let Err(err) = ret {
warn!("error serving connection from {}: {}", addr, err);
}
});
}
}
/// Handler for `GET /`: replies with a fixed plain-text greeting.
async fn handler() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}
fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
let mut key_reader = BufReader::new(File::open(key).unwrap());
let mut cert_reader = BufReader::new(File::open(cert).unwrap());
let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0));
let certs = certs(&mut cert_reader)
.unwrap()
.into_iter()
.map(Certificate)
.collect();
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.expect("bad certificate/key");
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(config)
}