diff --git a/axum/Cargo.toml b/axum/Cargo.toml index a1a37573..bb7e0889 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -51,8 +51,8 @@ tower-layer = "0.3.2" tower-service = "0.3" # wont need this when axum uses http-body 1.0 -hyper1 = { package = "hyper", version = "=1.0.0-rc.3", features = ["server", "http1"] } -tower-hyper-http-body-compat = { version = "0.1.4", features = ["server", "http1"] } +hyper1 = { package = "hyper", version = "=1.0.0-rc.4", features = ["server", "http1"] } +tower-hyper-http-body-compat = { version = "0.2", features = ["server", "http1"] } # optional dependencies axum-macros = { path = "../axum-macros", version = "0.3.7", optional = true } diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 7d3e235f..16ac8ad2 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -92,7 +92,7 @@ use self::rejection::*; use super::FromRequestParts; -use crate::{body::Bytes, response::Response, Error}; +use crate::{body::Bytes, hyper1_tokio_io::TokioIo, response::Response, Error}; use async_trait::async_trait; use axum_core::body::Body; use futures_util::{ @@ -293,6 +293,7 @@ impl WebSocketUpgrade { return; } }; + let upgraded = TokioIo::new(upgraded); let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) @@ -430,7 +431,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> /// See [the module level documentation](self) for more details. #[derive(Debug)] pub struct WebSocket { - inner: WebSocketStream, + inner: WebSocketStream>, protocol: Option, } diff --git a/axum/src/hyper1_tokio_io.rs b/axum/src/hyper1_tokio_io.rs new file mode 100644 index 00000000..40f1b2aa --- /dev/null +++ b/axum/src/hyper1_tokio_io.rs @@ -0,0 +1,162 @@ +// Copied from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio_io.rs + +#![allow(unsafe_code)] + +//! Tokio IO integration for hyper +use hyper1 as hyper; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use pin_project_lite::pin_project; + +pin_project! { + /// A wrapping implementing hyper IO traits for a type that + /// implements Tokio's IO traits. + #[derive(Debug)] + pub(crate) struct TokioIo { + #[pin] + inner: T, + } +} + +impl TokioIo { + /// Wrap a type implementing Tokio's IO traits. + pub(crate) fn new(inner: T) -> Self { + Self { inner } + } + + /// Borrow the inner type. + pub(crate) fn inner(&self) -> &T { + &self.inner + } +} + +impl hyper::rt::Read for TokioIo +where + T: tokio::io::AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } +} + +impl hyper::rt::Write for TokioIo +where + T: tokio::io::AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) + } +} + +impl tokio::io::AsyncRead for TokioIo +where + T: hyper::rt::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + tbuf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + //let init = tbuf.initialized().len(); + let filled = tbuf.filled().len(); + let sub_filled = unsafe { + let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); + + match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { + Poll::Ready(Ok(())) => buf.filled().len(), + other => return other, + } + }; + + let n_filled = filled + sub_filled; + // At least sub_filled bytes had to have been initialized. + let n_init = sub_filled; + unsafe { + tbuf.assume_init(n_init); + tbuf.set_filled(n_filled); + } + + Poll::Ready(Ok(())) + } +} + +impl tokio::io::AsyncWrite for TokioIo +where + T: hyper::rt::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + hyper::rt::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + hyper::rt::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + hyper::rt::Write::poll_shutdown(self.project().inner, cx) + } + + fn is_write_vectored(&self) -> bool { + hyper::rt::Write::is_write_vectored(&self.inner) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) + } +} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 6fbb45b4..d3e38c36 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -418,7 +418,9 @@ )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] -#![forbid(unsafe_code)] +// can't be `forbid` since we've vendored code from hyper-util that contains `unsafe` +// when hyper-util is on crates.io we can stop vendoring it and go back to `forbid` +#![deny(unsafe_code)] #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] @@ -430,6 +432,8 @@ mod boxed; mod extension; #[cfg(feature = "form")] mod form; +#[cfg(feature = "tokio")] +mod hyper1_tokio_io; #[cfg(feature = "json")] mod json; mod service_ext; diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 096e7ec7..027d735e 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -916,7 +916,7 @@ async fn state_isnt_cloned_too_much() { client.get("/").send().await; - assert_eq!(COUNT.load(Ordering::SeqCst), 4); + assert_eq!(COUNT.load(Ordering::SeqCst), 5); } #[crate::test] diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 83f529c8..76058da1 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -2,6 +2,7 @@ use std::{convert::Infallible, io, net::SocketAddr}; +use crate::hyper1_tokio_io::TokioIo; use axum_core::{body::Body, extract::Request, response::Response}; use futures_util::{future::poll_fn, FutureExt}; use hyper1::server::conn::http1; @@ -82,12 +83,13 @@ where { loop { let (tcp_stream, remote_addr) = tcp_listener.accept().await?; + let tcp_stream = TokioIo::new(tcp_stream); poll_fn(|cx| make_service.poll_ready(cx)) .await .unwrap_or_else(|err| match err {}); - let mut service = make_service + let service = make_service .call(IncomingStream { tcp_stream: &tcp_stream, remote_addr, @@ -96,6 +98,10 @@ where .unwrap_or_else(|err| match err {}); let service = hyper1::service::service_fn(move |req: Request| { + // `hyper1::service::service_fn` takes an `Fn` closure. So we need an owned service in + // order to call `poll_ready` and `call` which need `&mut self` + let mut service = service.clone(); + let req = req.map(|body| { // wont need this when axum uses http-body 1.0 let http_body_04 = HttpBody1ToHttpBody04::new(body); @@ -158,14 +164,14 @@ where /// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo #[derive(Debug)] pub struct IncomingStream<'a> { - tcp_stream: &'a TcpStream, + tcp_stream: &'a TokioIo, remote_addr: SocketAddr, } impl IncomingStream<'_> { /// Returns the local address that this stream is bound to. pub fn local_addr(&self) -> std::io::Result { - self.tcp_stream.local_addr() + self.tcp_stream.inner().local_addr() } /// Returns the remote address that this stream is bound to. diff --git a/examples/hyper-1-0/Cargo.toml b/examples/hyper-1-0/Cargo.toml index 772b268d..f55e532a 100644 --- a/examples/hyper-1-0/Cargo.toml +++ b/examples/hyper-1-0/Cargo.toml @@ -6,9 +6,10 @@ publish = false [dependencies] axum = { path = "../../axum" } -hyper = { version = "1.0.0-rc.3", features = ["full"] } +hyper = { version = "=1.0.0-rc.4", features = ["full"] } +hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "f898015", features = ["full"] } tokio = { version = "1.0", features = ["full"] } tower-http = { version = "0.4", features = ["trace"] } -tower-hyper-http-body-compat = { version = "0.1", features = ["http1", "server"] } +tower-hyper-http-body-compat = { version = "0.2", features = ["http1", "server"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/hyper-1-0/src/main.rs b/examples/hyper-1-0/src/main.rs index 51493d44..06216b57 100644 --- a/examples/hyper-1-0/src/main.rs +++ b/examples/hyper-1-0/src/main.rs @@ -38,6 +38,7 @@ async fn main() { tracing::debug!("listening on {addr}"); loop { let (tcp_stream, _) = tcp_listener.accept().await.unwrap(); + let tcp_stream = hyper_util::rt::TokioIo::new(tcp_stream); let service = service.clone(); tokio::task::spawn(async move { if let Err(http_err) = http1::Builder::new()