Update to hyper 1.0.0-rc.4 (#2094)

This commit is contained in:
David Pedersen 2023-07-15 17:38:38 +02:00 committed by GitHub
parent 5d96ca9fcd
commit b34715fe81
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 186 additions and 11 deletions

View file

@ -51,8 +51,8 @@ tower-layer = "0.3.2"
tower-service = "0.3"
# wont need this when axum uses http-body 1.0
hyper1 = { package = "hyper", version = "=1.0.0-rc.3", features = ["server", "http1"] }
tower-hyper-http-body-compat = { version = "0.1.4", features = ["server", "http1"] }
hyper1 = { package = "hyper", version = "=1.0.0-rc.4", features = ["server", "http1"] }
tower-hyper-http-body-compat = { version = "0.2", features = ["server", "http1"] }
# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.3.7", optional = true }

View file

@ -92,7 +92,7 @@
use self::rejection::*;
use super::FromRequestParts;
use crate::{body::Bytes, response::Response, Error};
use crate::{body::Bytes, hyper1_tokio_io::TokioIo, response::Response, Error};
use async_trait::async_trait;
use axum_core::body::Body;
use futures_util::{
@ -293,6 +293,7 @@ impl<F> WebSocketUpgrade<F> {
return;
}
};
let upgraded = TokioIo::new(upgraded);
let socket =
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
@ -430,7 +431,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<hyper1::upgrade::Upgraded>,
inner: WebSocketStream<TokioIo<hyper1::upgrade::Upgraded>>,
protocol: Option<HeaderValue>,
}

162
axum/src/hyper1_tokio_io.rs Normal file
View file

@ -0,0 +1,162 @@
// Copied from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio_io.rs
#![allow(unsafe_code)]
//! Tokio IO integration for hyper
use hyper1 as hyper;
use std::{
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;
pin_project! {
/// A wrapping implementing hyper IO traits for a type that
/// implements Tokio's IO traits.
#[derive(Debug)]
pub(crate) struct TokioIo<T> {
#[pin]
inner: T,
}
}
impl<T> TokioIo<T> {
/// Wrap a type implementing Tokio's IO traits.
pub(crate) fn new(inner: T) -> Self {
Self { inner }
}
/// Borrow the inner type.
pub(crate) fn inner(&self) -> &T {
&self.inner
}
}
impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};
unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}
impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
T: hyper::rt::Read,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len();
let filled = tbuf.filled().len();
let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(),
other => return other,
}
};
let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized.
let n_init = sub_filled;
unsafe {
tbuf.assume_init(n_init);
tbuf.set_filled(n_filled);
}
Poll::Ready(Ok(()))
}
}
impl<T> tokio::io::AsyncWrite for TokioIo<T>
where
T: hyper::rt::Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_shutdown(self.project().inner, cx)
}
fn is_write_vectored(&self) -> bool {
hyper::rt::Write::is_write_vectored(&self.inner)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}

View file

@ -418,7 +418,9 @@
)]
#![deny(unreachable_pub, private_in_public)]
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
#![forbid(unsafe_code)]
// can't be `forbid` since we've vendored code from hyper-util that contains `unsafe`
// when hyper-util is on crates.io we can stop vendoring it and go back to `forbid`
#![deny(unsafe_code)]
#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
#![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))]
@ -430,6 +432,8 @@ mod boxed;
mod extension;
#[cfg(feature = "form")]
mod form;
#[cfg(feature = "tokio")]
mod hyper1_tokio_io;
#[cfg(feature = "json")]
mod json;
mod service_ext;

View file

@ -916,7 +916,7 @@ async fn state_isnt_cloned_too_much() {
client.get("/").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
assert_eq!(COUNT.load(Ordering::SeqCst), 5);
}
#[crate::test]

View file

@ -2,6 +2,7 @@
use std::{convert::Infallible, io, net::SocketAddr};
use crate::hyper1_tokio_io::TokioIo;
use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{future::poll_fn, FutureExt};
use hyper1::server::conn::http1;
@ -82,12 +83,13 @@ where
{
loop {
let (tcp_stream, remote_addr) = tcp_listener.accept().await?;
let tcp_stream = TokioIo::new(tcp_stream);
poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});
let mut service = make_service
let service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
@ -96,6 +98,10 @@ where
.unwrap_or_else(|err| match err {});
let service = hyper1::service::service_fn(move |req: Request<hyper1::body::Incoming>| {
// `hyper1::service::service_fn` takes an `Fn` closure. So we need an owned service in
// order to call `poll_ready` and `call` which need `&mut self`
let mut service = service.clone();
let req = req.map(|body| {
// wont need this when axum uses http-body 1.0
let http_body_04 = HttpBody1ToHttpBody04::new(body);
@ -158,14 +164,14 @@ where
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
#[derive(Debug)]
pub struct IncomingStream<'a> {
tcp_stream: &'a TcpStream,
tcp_stream: &'a TokioIo<TcpStream>,
remote_addr: SocketAddr,
}
impl IncomingStream<'_> {
/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.tcp_stream.local_addr()
self.tcp_stream.inner().local_addr()
}
/// Returns the remote address that this stream is bound to.

View file

@ -6,9 +6,10 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
hyper = { version = "1.0.0-rc.3", features = ["full"] }
hyper = { version = "=1.0.0-rc.4", features = ["full"] }
hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "f898015", features = ["full"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.4", features = ["trace"] }
tower-hyper-http-body-compat = { version = "0.1", features = ["http1", "server"] }
tower-hyper-http-body-compat = { version = "0.2", features = ["http1", "server"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View file

@ -38,6 +38,7 @@ async fn main() {
tracing::debug!("listening on {addr}");
loop {
let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
let tcp_stream = hyper_util::rt::TokioIo::new(tcp_stream);
let service = service.clone();
tokio::task::spawn(async move {
if let Err(http_err) = http1::Builder::new()