mirror of
https://github.com/tokio-rs/axum.git
synced 2024-11-21 22:56:46 +01:00
Update to hyper 1.0.0-rc.4 (#2094)
This commit is contained in:
parent
5d96ca9fcd
commit
b34715fe81
8 changed files with 186 additions and 11 deletions
|
@ -51,8 +51,8 @@ tower-layer = "0.3.2"
|
|||
tower-service = "0.3"
|
||||
|
||||
# wont need this when axum uses http-body 1.0
|
||||
hyper1 = { package = "hyper", version = "=1.0.0-rc.3", features = ["server", "http1"] }
|
||||
tower-hyper-http-body-compat = { version = "0.1.4", features = ["server", "http1"] }
|
||||
hyper1 = { package = "hyper", version = "=1.0.0-rc.4", features = ["server", "http1"] }
|
||||
tower-hyper-http-body-compat = { version = "0.2", features = ["server", "http1"] }
|
||||
|
||||
# optional dependencies
|
||||
axum-macros = { path = "../axum-macros", version = "0.3.7", optional = true }
|
||||
|
|
|
@ -92,7 +92,7 @@
|
|||
|
||||
use self::rejection::*;
|
||||
use super::FromRequestParts;
|
||||
use crate::{body::Bytes, response::Response, Error};
|
||||
use crate::{body::Bytes, hyper1_tokio_io::TokioIo, response::Response, Error};
|
||||
use async_trait::async_trait;
|
||||
use axum_core::body::Body;
|
||||
use futures_util::{
|
||||
|
@ -293,6 +293,7 @@ impl<F> WebSocketUpgrade<F> {
|
|||
return;
|
||||
}
|
||||
};
|
||||
let upgraded = TokioIo::new(upgraded);
|
||||
|
||||
let socket =
|
||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
||||
|
@ -430,7 +431,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
|
|||
/// See [the module level documentation](self) for more details.
|
||||
#[derive(Debug)]
|
||||
pub struct WebSocket {
|
||||
inner: WebSocketStream<hyper1::upgrade::Upgraded>,
|
||||
inner: WebSocketStream<TokioIo<hyper1::upgrade::Upgraded>>,
|
||||
protocol: Option<HeaderValue>,
|
||||
}
|
||||
|
||||
|
|
162
axum/src/hyper1_tokio_io.rs
Normal file
162
axum/src/hyper1_tokio_io.rs
Normal file
|
@ -0,0 +1,162 @@
|
|||
// Copied from https://github.com/hyperium/hyper-util/blob/master/src/rt/tokio_io.rs
|
||||
|
||||
#![allow(unsafe_code)]
|
||||
|
||||
//! Tokio IO integration for hyper
|
||||
use hyper1 as hyper;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
pin_project! {
|
||||
/// A wrapping implementing hyper IO traits for a type that
|
||||
/// implements Tokio's IO traits.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TokioIo<T> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TokioIo<T> {
|
||||
/// Wrap a type implementing Tokio's IO traits.
|
||||
pub(crate) fn new(inner: T) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
/// Borrow the inner type.
|
||||
pub(crate) fn inner(&self) -> &T {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> hyper::rt::Read for TokioIo<T>
|
||||
where
|
||||
T: tokio::io::AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
mut buf: hyper::rt::ReadBufCursor<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let n = unsafe {
|
||||
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
|
||||
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
|
||||
Poll::Ready(Ok(())) => tbuf.filled().len(),
|
||||
other => return other,
|
||||
}
|
||||
};
|
||||
|
||||
unsafe {
|
||||
buf.advance(n);
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> hyper::rt::Write for TokioIo<T>
|
||||
where
|
||||
T: tokio::io::AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> tokio::io::AsyncRead for TokioIo<T>
|
||||
where
|
||||
T: hyper::rt::Read,
|
||||
{
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
tbuf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
//let init = tbuf.initialized().len();
|
||||
let filled = tbuf.filled().len();
|
||||
let sub_filled = unsafe {
|
||||
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
|
||||
|
||||
match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
|
||||
Poll::Ready(Ok(())) => buf.filled().len(),
|
||||
other => return other,
|
||||
}
|
||||
};
|
||||
|
||||
let n_filled = filled + sub_filled;
|
||||
// At least sub_filled bytes had to have been initialized.
|
||||
let n_init = sub_filled;
|
||||
unsafe {
|
||||
tbuf.assume_init(n_init);
|
||||
tbuf.set_filled(n_filled);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> tokio::io::AsyncWrite for TokioIo<T>
|
||||
where
|
||||
T: hyper::rt::Write,
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
hyper::rt::Write::poll_write(self.project().inner, cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
hyper::rt::Write::poll_flush(self.project().inner, cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
hyper::rt::Write::poll_shutdown(self.project().inner, cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
hyper::rt::Write::is_write_vectored(&self.inner)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
|
||||
}
|
||||
}
|
|
@ -418,7 +418,9 @@
|
|||
)]
|
||||
#![deny(unreachable_pub, private_in_public)]
|
||||
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
|
||||
#![forbid(unsafe_code)]
|
||||
// can't be `forbid` since we've vendored code from hyper-util that contains `unsafe`
|
||||
// when hyper-util is on crates.io we can stop vendoring it and go back to `forbid`
|
||||
#![deny(unsafe_code)]
|
||||
#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
|
||||
#![cfg_attr(test, allow(clippy::float_cmp))]
|
||||
#![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))]
|
||||
|
@ -430,6 +432,8 @@ mod boxed;
|
|||
mod extension;
|
||||
#[cfg(feature = "form")]
|
||||
mod form;
|
||||
#[cfg(feature = "tokio")]
|
||||
mod hyper1_tokio_io;
|
||||
#[cfg(feature = "json")]
|
||||
mod json;
|
||||
mod service_ext;
|
||||
|
|
|
@ -916,7 +916,7 @@ async fn state_isnt_cloned_too_much() {
|
|||
|
||||
client.get("/").send().await;
|
||||
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 5);
|
||||
}
|
||||
|
||||
#[crate::test]
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
use std::{convert::Infallible, io, net::SocketAddr};
|
||||
|
||||
use crate::hyper1_tokio_io::TokioIo;
|
||||
use axum_core::{body::Body, extract::Request, response::Response};
|
||||
use futures_util::{future::poll_fn, FutureExt};
|
||||
use hyper1::server::conn::http1;
|
||||
|
@ -82,12 +83,13 @@ where
|
|||
{
|
||||
loop {
|
||||
let (tcp_stream, remote_addr) = tcp_listener.accept().await?;
|
||||
let tcp_stream = TokioIo::new(tcp_stream);
|
||||
|
||||
poll_fn(|cx| make_service.poll_ready(cx))
|
||||
.await
|
||||
.unwrap_or_else(|err| match err {});
|
||||
|
||||
let mut service = make_service
|
||||
let service = make_service
|
||||
.call(IncomingStream {
|
||||
tcp_stream: &tcp_stream,
|
||||
remote_addr,
|
||||
|
@ -96,6 +98,10 @@ where
|
|||
.unwrap_or_else(|err| match err {});
|
||||
|
||||
let service = hyper1::service::service_fn(move |req: Request<hyper1::body::Incoming>| {
|
||||
// `hyper1::service::service_fn` takes an `Fn` closure. So we need an owned service in
|
||||
// order to call `poll_ready` and `call` which need `&mut self`
|
||||
let mut service = service.clone();
|
||||
|
||||
let req = req.map(|body| {
|
||||
// wont need this when axum uses http-body 1.0
|
||||
let http_body_04 = HttpBody1ToHttpBody04::new(body);
|
||||
|
@ -158,14 +164,14 @@ where
|
|||
/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
|
||||
#[derive(Debug)]
|
||||
pub struct IncomingStream<'a> {
|
||||
tcp_stream: &'a TcpStream,
|
||||
tcp_stream: &'a TokioIo<TcpStream>,
|
||||
remote_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl IncomingStream<'_> {
|
||||
/// Returns the local address that this stream is bound to.
|
||||
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
|
||||
self.tcp_stream.local_addr()
|
||||
self.tcp_stream.inner().local_addr()
|
||||
}
|
||||
|
||||
/// Returns the remote address that this stream is bound to.
|
||||
|
|
|
@ -6,9 +6,10 @@ publish = false
|
|||
|
||||
[dependencies]
|
||||
axum = { path = "../../axum" }
|
||||
hyper = { version = "1.0.0-rc.3", features = ["full"] }
|
||||
hyper = { version = "=1.0.0-rc.4", features = ["full"] }
|
||||
hyper-util = { git = "https://github.com/hyperium/hyper-util", rev = "f898015", features = ["full"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tower-http = { version = "0.4", features = ["trace"] }
|
||||
tower-hyper-http-body-compat = { version = "0.1", features = ["http1", "server"] }
|
||||
tower-hyper-http-body-compat = { version = "0.2", features = ["http1", "server"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
|
|
@ -38,6 +38,7 @@ async fn main() {
|
|||
tracing::debug!("listening on {addr}");
|
||||
loop {
|
||||
let (tcp_stream, _) = tcp_listener.accept().await.unwrap();
|
||||
let tcp_stream = hyper_util::rt::TokioIo::new(tcp_stream);
|
||||
let service = service.clone();
|
||||
tokio::task::spawn(async move {
|
||||
if let Err(http_err) = http1::Builder::new()
|
||||
|
|
Loading…
Reference in a new issue