diff --git a/examples/README.md b/examples/README.md index 570b9cce..d1e3dc8d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -12,3 +12,4 @@ - [`websocket`](../examples/websocket.rs) - How to build an app that handles WebSocket connections. - [`error_handling_and_dependency_injection`](../examples/error_handling_and_dependency_injection.rs) - How to handle errors and dependency injection using trait objects. - [`tokio_postgres`](../examples/tokio_postgres.rs) - How to use a tokio-postgres and bb8 to query a database. +- [`unix_domain_socket`](../examples/unix_domain_socket.rs) - How to run an Axum server over unix domain sockets. diff --git a/examples/unix_domain_socket.rs b/examples/unix_domain_socket.rs new file mode 100644 index 00000000..dc5259a8 --- /dev/null +++ b/examples/unix_domain_socket.rs @@ -0,0 +1,126 @@ +use axum::prelude::*; +use futures::ready; +use http::{Method, StatusCode, Uri}; +use hyper::{ + client::connect::{Connected, Connection}, + server::accept::Accept, +}; +use std::{ + io, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::net::UnixListener; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::UnixStream, +}; +use tower::BoxError; + +#[cfg(not(unix))] +fn main() { + println!("This example requires unix") +} + +#[cfg(unix)] +#[tokio::main] +async fn main() { + let path = PathBuf::from("/tmp/axum/helloworld"); + + let _ = tokio::fs::remove_file(&path).await; + tokio::fs::create_dir_all(path.parent().unwrap()) + .await + .unwrap(); + + let uds = UnixListener::bind(path.clone()).unwrap(); + tokio::spawn(async { + let app = route("/", get(|| async { "Hello, World!" })); + + hyper::Server::builder(ServerAccept { uds }) + .serve(app.into_make_service()) + .await + .unwrap(); + }); + + let connector = tower::service_fn(move |_: Uri| { + let path = path.clone(); + Box::pin(async move { + let stream = UnixStream::connect(path).await?; + Ok::<_, io::Error>(ClientConnection { stream }) + }) + }); + let client = hyper::Client::builder().build(connector); + + let request = Request::builder() + .method(Method::GET) + .uri("http://uri-doesnt-matter.com") + .body(Body::empty()) + .unwrap(); + + let response = client.request(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); +} + +struct ServerAccept { + uds: UnixListener, +} + +impl Accept for ServerAccept { + type Conn = UnixStream; + type Error = BoxError; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let (stream, _addr) = ready!(self.uds.poll_accept(cx))?; + Poll::Ready(Some(Ok(stream))) + } +} + +struct ClientConnection { + stream: UnixStream, +} + +impl AsyncWrite for ClientConnection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +impl AsyncRead for ClientConnection { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl Connection for ClientConnection { + fn connected(&self) -> Connected { + Connected::new() + } +}