1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-26 13:56:22 +02:00

Update to hyper 1.0-rc.3 and http-body-util 0.1.0-rc.2

This commit is contained in:
David Pedersen 2023-03-23 23:23:39 +01:00
parent 44bb38dfa2
commit c73b6a9969
41 changed files with 790 additions and 816 deletions
axum-core
axum-extra
axum
examples
consume-body-in-extractor-or-middleware
graceful-shutdown/src
handle-head-request
http-proxy/src
listen-multiple-addrs/src
low-level-openssl/src
low-level-rustls/src
print-request-response
query-params-with-empty-strings
rest-grpc-multiplex/src
reverse-proxy/src
testing
tls-rustls/src
unix-domain-socket/src

View file

@ -20,6 +20,7 @@ bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2.7"
http-body = "1.0.0-rc.2"
http-body-util = "0.1.0-rc.2"
mime = "0.3.16"
pin-project-lite = "0.2.7"
sync_wrapper = "0.1.1"

View file

@ -2,17 +2,16 @@
use crate::{BoxError, Error};
use bytes::Bytes;
use bytes::{Buf, BufMut};
use futures_util::stream::Stream;
use futures_util::TryStream;
use http::HeaderMap;
use http_body::Body as _;
use http_body::{Body as _, Frame};
use http_body_util::BodyExt;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use sync_wrapper::SyncWrapper;
type BoxBody = http_body::combinators::UnsyncBoxBody<Bytes, Error>;
type BoxBody = http_body_util::combinators::UnsyncBoxBody<Bytes, Error>;
fn boxed<B>(body: B) -> BoxBody
where
@ -35,58 +34,6 @@ where
}
}
// copied from hyper under the following license:
// Copyright (c) 2014-2021 Sean McArthur
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
where
T: http_body::Body,
{
futures_util::pin_mut!(body);
// If there's only 1 chunk, we can just return Buf::to_bytes()
let mut first = if let Some(buf) = body.data().await {
buf?
} else {
return Ok(Bytes::new());
};
let second = if let Some(buf) = body.data().await {
buf?
} else {
return Ok(first.copy_to_bytes(first.remaining()));
};
// With more than 1 buf, we gotta flatten into a Vec first.
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
let mut vec = Vec::with_capacity(cap);
vec.put(first);
vec.put(second);
while let Some(buf) = body.data().await {
vec.put(buf?);
}
Ok(vec.into())
}
/// The body type used in axum requests and responses.
#[derive(Debug)]
pub struct Body(BoxBody);
@ -103,7 +50,7 @@ impl Body {
/// Create an empty body.
pub fn empty() -> Self {
Self::new(http_body::Empty::new())
Self::new(http_body_util::Empty::new())
}
/// Create a new `Body` from a [`Stream`].
@ -131,7 +78,7 @@ macro_rules! body_from_impl {
($ty:ty) => {
impl From<$ty> for Body {
fn from(buf: $ty) -> Self {
Self::new(http_body::Full::from(buf))
Self::new(http_body_util::Full::from(buf))
}
}
};
@ -152,19 +99,11 @@ impl http_body::Body for Body {
type Error = Error;
#[inline]
fn poll_data(
fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> std::task::Poll<Option<Result<Self::Data, Self::Error>>> {
Pin::new(&mut self.0).poll_data(cx)
}
#[inline]
fn poll_trailers(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> std::task::Poll<Result<Option<HeaderMap>, Self::Error>> {
Pin::new(&mut self.0).poll_trailers(cx)
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
Pin::new(&mut self.0).poll_frame(cx)
}
#[inline]
@ -182,8 +121,16 @@ impl Stream for Body {
type Item = Result<Bytes, Error>;
#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_data(cx)
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures_util::ready!(self.as_mut().poll_frame(cx)?) {
Some(frame) => match frame.into_data() {
Ok(data) => return Poll::Ready(Some(Ok(data))),
Err(_frame) => {}
},
None => return Poll::Ready(None),
}
}
}
}
@ -203,25 +150,17 @@ where
type Data = Bytes;
type Error = Error;
fn poll_data(
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let stream = self.project().stream.get_pin_mut();
match futures_util::ready!(stream.try_poll_next(cx)) {
Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))),
Some(Ok(chunk)) => Poll::Ready(Some(Ok(Frame::data(chunk.into())))),
Some(Err(err)) => Poll::Ready(Some(Err(Error::new(err)))),
None => Poll::Ready(None),
}
}
#[inline]
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}
#[test]

View file

@ -1,7 +1,7 @@
use crate::body::Body;
use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request};
use futures_util::future::BoxFuture;
use http_body::Limited;
use http_body_util::Limited;
mod sealed {
pub trait Sealed {}
@ -326,9 +326,9 @@ impl RequestExt for Request {
match self.extensions().get::<DefaultBodyLimitKind>().copied() {
Some(DefaultBodyLimitKind::Disable) => Err(self),
Some(DefaultBodyLimitKind::Limit(limit)) => {
Ok(self.map(|b| http_body::Limited::new(b, limit)))
Ok(self.map(|b| http_body_util::Limited::new(b, limit)))
}
None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))),
None => Ok(self.map(|b| http_body_util::Limited::new(b, DEFAULT_LIMIT))),
}
}

View file

@ -20,7 +20,7 @@ impl FailedToBufferBody {
Ok(err) => err.into_inner(),
Err(err) => err,
};
match box_error.downcast::<http_body::LengthLimitError>() {
match box_error.downcast::<http_body_util::LengthLimitError>() {
Ok(err) => Self::LengthLimitError(LengthLimitError::from_err(err)),
Err(err) => Self::UnknownBodyError(UnknownBodyError::from_err(err)),
}

View file

@ -3,6 +3,7 @@ use crate::{body::Body, RequestExt};
use async_trait::async_trait;
use bytes::Bytes;
use http::{request::Parts, HeaderMap, Method, Uri, Version};
use http_body_util::BodyExt;
use std::convert::Infallible;
#[async_trait]
@ -79,12 +80,16 @@ where
async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
let bytes = match req.into_limited_body() {
Ok(limited_body) => crate::body::to_bytes(limited_body)
Ok(limited_body) => limited_body
.collect()
.await
.map_err(FailedToBufferBody::from_err)?,
Err(unlimited_body) => crate::body::to_bytes(unlimited_body)
.map_err(FailedToBufferBody::from_err)?
.to_bytes(),
Err(unlimited_body) => unlimited_body
.collect()
.await
.map_err(FailedToBufferBody::from_err)?,
.map_err(FailedToBufferBody::from_err)?
.to_bytes(),
};
Ok(bytes)

View file

@ -5,7 +5,7 @@ use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
Extensions, StatusCode,
};
use http_body::SizeHint;
use http_body::{Frame, SizeHint};
use std::{
borrow::Cow,
convert::Infallible,
@ -250,30 +250,23 @@ where
type Data = Bytes;
type Error = Infallible;
fn poll_data(
fn poll_frame(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
if let Some(mut buf) = self.first.take() {
let bytes = buf.copy_to_bytes(buf.remaining());
return Poll::Ready(Some(Ok(bytes)));
return Poll::Ready(Some(Ok(Frame::data(bytes))));
}
if let Some(mut buf) = self.second.take() {
let bytes = buf.copy_to_bytes(buf.remaining());
return Poll::Ready(Some(Ok(bytes)));
return Poll::Ready(Some(Ok(Frame::data(bytes))));
}
Poll::Ready(None)
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
fn is_end_stream(&self) -> bool {
self.first.is_none() && self.second.is_none()
}

View file

@ -40,6 +40,7 @@ bytes = "1.1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2"
http-body = "1.0.0-rc.2"
http-body-util = "0.1.0-rc.2"
mime = "0.3"
pin-project-lite = "0.2"
tokio = "1.19"

View file

@ -1,6 +1,5 @@
use axum::{
body::{Body, Bytes, HttpBody},
http::HeaderMap,
response::{IntoResponse, Response},
Error,
};
@ -69,18 +68,22 @@ impl HttpBody for AsyncReadBody {
type Data = Bytes;
type Error = Error;
fn poll_data(
#[inline]
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.project().body.poll_data(cx)
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
self.project().body.poll_frame(cx)
}
fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
self.project().body.poll_trailers(cx)
#[inline]
fn is_end_stream(&self) -> bool {
self.body.is_end_stream()
}
#[inline]
fn size_hint(&self) -> http_body::SizeHint {
self.body.size_hint()
}
}

View file

@ -232,6 +232,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
mod tests {
use super::*;
use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
use http_body_util::BodyExt;
use tower::ServiceExt;
macro_rules! cookie_test {
@ -376,7 +377,7 @@ mod tests {
B: axum::body::HttpBody,
B::Error: std::fmt::Debug,
{
let bytes = hyper::body::to_bytes(body).await.unwrap();
let bytes = body.collect().await.unwrap().to_bytes();
String::from_utf8(bytes.to_vec()).unwrap()
}
}

View file

@ -278,7 +278,7 @@ fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
if err
.downcast_ref::<axum::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;

View file

@ -148,6 +148,7 @@ mod tests {
use super::*;
use axum::{body::Body, extract::Path, http::Method, Router};
use http::Request;
use http_body_util::BodyExt;
use tower::ServiceExt;
#[tokio::test]
@ -216,7 +217,7 @@ mod tests {
)
.await
.unwrap();
let bytes = hyper::body::to_bytes(res).await.unwrap();
let bytes = res.collect().await.unwrap().to_bytes();
String::from_utf8(bytes.to_vec()).unwrap()
}
}

View file

@ -37,6 +37,7 @@ bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2.9"
http-body = "1.0.0-rc.2"
http-body-util = "0.1.0-rc.2"
hyper = "1.0.0-rc.3"
itoa = "1.0.5"
matchit = "0.7"

View file

@ -3,7 +3,6 @@ use axum::{
routing::{get, post},
Extension, Json, Router,
};
use hyper::server::conn::AddrIncoming;
use serde::{Deserialize, Serialize};
use std::{
io::BufRead,
@ -149,13 +148,7 @@ impl BenchmarkBuilder {
let addr = listener.local_addr().unwrap();
std::thread::spawn(move || {
rt.block_on(async move {
let incoming = AddrIncoming::from_listener(listener).unwrap();
hyper::Server::builder(incoming)
.serve(app.into_make_service())
.await
.unwrap();
});
rt.block_on(axum::serve(listener, app));
});
let mut cmd = Command::new("rewrk");

View file

@ -8,7 +8,6 @@ use super::{Extension, FromRequestParts};
use crate::{middleware::AddExtension, serve::IncomingStream};
use async_trait::async_trait;
use http::request::Parts;
use hyper::server::conn::AddrStream;
use std::{
convert::Infallible,
fmt,
@ -83,12 +82,6 @@ pub trait Connected<T>: Clone + Send + Sync + 'static {
fn connect_info(target: T) -> Self;
}
impl Connected<&AddrStream> for SocketAddr {
fn connect_info(target: &AddrStream) -> Self {
target.remote_addr()
}
}
impl Connected<IncomingStream<'_>> for SocketAddr {
fn connect_info(target: IncomingStream<'_>) -> Self {
target.remote_addr()

View file

@ -240,7 +240,7 @@ fn status_code_from_multer_error(err: &multer::Error) -> StatusCode {
if err
.downcast_ref::<crate::Error>()
.and_then(|err| err.source())
.and_then(|err| err.downcast_ref::<http_body::LengthLimitError>())
.and_then(|err| err.downcast_ref::<http_body_util::LengthLimitError>())
.is_some()
{
return StatusCode::PAYLOAD_TOO_LARGE;

View file

@ -132,7 +132,7 @@ pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
on_upgrade: hyper1::upgrade::OnUpgrade,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
}
@ -386,7 +386,7 @@ where
let on_upgrade = parts
.extensions
.remove::<hyper1::upgrade::OnUpgrade>()
.remove::<hyper::upgrade::OnUpgrade>()
.ok_or(ConnectionNotUpgradable)?;
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
@ -429,7 +429,7 @@ fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) ->
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<hyper1::upgrade::Upgraded>,
inner: WebSocketStream<hyper::upgrade::Upgraded>,
protocol: Option<HeaderValue>,
}

View file

@ -376,6 +376,7 @@ mod tests {
use super::*;
use crate::{body::Body, routing::get, Router};
use http::{HeaderMap, StatusCode};
use http_body_util::BodyExt;
use tower::ServiceExt;
#[crate::test]
@ -400,7 +401,7 @@ mod tests {
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body = hyper::body::to_bytes(res).await.unwrap();
let body = res.collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"ok");
}
}

View file

@ -40,6 +40,7 @@ use futures_util::{
ready,
stream::{Stream, TryStream},
};
use http_body::Frame;
use pin_project_lite::pin_project;
use std::{
fmt,
@ -129,16 +130,16 @@ where
type Data = Bytes;
type Error = E;
fn poll_data(
fn poll_frame(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let this = self.project();
match this.event_stream.get_pin_mut().poll_next(cx) {
Poll::Pending => {
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
} else {
Poll::Pending
}
@ -147,19 +148,12 @@ where
if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
keep_alive.reset();
}
Poll::Ready(Some(Ok(event.finalize())))
Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
}
Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
Poll::Ready(None) => Poll::Ready(None),
}
}
fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}
/// Server-sent event

View file

@ -1260,6 +1260,7 @@ mod tests {
};
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use http_body_util::BodyExt;
use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};
@ -1553,7 +1554,8 @@ mod tests {
.unwrap()
.into_response();
let (parts, body) = response.into_parts();
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
let body =
String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
(parts.status, parts.headers, body)
}

View file

@ -32,7 +32,7 @@ mod for_handlers {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.headers()["x-some-header"], "foobar");
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
let body = BodyExt::collect(res.into_body()).await.unwrap().to_bytes();
assert_eq!(body.len(), 0);
}
}
@ -67,7 +67,7 @@ mod for_services {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.headers()["x-some-header"], "foobar");
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
let body = BodyExt::collect(res.into_body()).await.unwrap().to_bytes();
assert_eq!(body.len(), 0);
}
}

View file

@ -11,6 +11,7 @@ use crate::{
use axum_core::extract::Request;
use futures_util::stream::StreamExt;
use http::{header::ALLOW, header::CONTENT_LENGTH, HeaderMap, StatusCode, Uri};
use http_body_util::BodyExt;
use serde_json::json;
use std::{
convert::Infallible,

View file

@ -4,9 +4,8 @@ use std::{convert::Infallible, io, net::SocketAddr};
use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{future::poll_fn, FutureExt};
use hyper1::server::conn::http1;
use hyper::server::conn::http1;
use tokio::net::{TcpListener, TcpStream};
use tower_hyper_http_body_compat::{HttpBody04ToHttpBody1, HttpBody1ToHttpBody04};
use tower_service::Service;
/// Serve the service with the supplied listener.
@ -95,13 +94,7 @@ where
.await
.unwrap_or_else(|err| match err {});
let service = hyper1::service::service_fn(move |req: Request<hyper1::body::Incoming>| {
let req = req.map(|body| {
// wont need this when axum uses http-body 1.0
let http_body_04 = HttpBody1ToHttpBody04::new(body);
Body::new(http_body_04)
});
let service = hyper::service::service_fn(move |req: Request<hyper::body::Incoming>| {
// doing this saves cloning the service just to await the service being ready
//
// services like `Router` are always ready, so assume the service
@ -111,20 +104,16 @@ where
Some(Err(err)) => match err {},
None => {
// ...otherwise load shed
let mut res = Response::new(HttpBody04ToHttpBody1::new(Body::empty()));
let mut res = Response::new(Body::empty());
*res.status_mut() = http::StatusCode::SERVICE_UNAVAILABLE;
return std::future::ready(Ok(res)).left_future();
}
}
let future = service.call(req);
let future = service.call(req.map(Body::new));
async move {
let response = future
.await
.unwrap_or_else(|err| match err {})
// wont need this when axum uses http-body 1.0
.map(HttpBody04ToHttpBody1::new);
let response = future.await.unwrap_or_else(|err| match err {});
Ok::<_, Infallible>(response)
}

View file

@ -6,6 +6,7 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0-rc.2"
hyper = "1.0.0-rc.3"
tokio = { version = "1.0", features = ["full"] }
tower = "0.4"

View file

@ -14,6 +14,7 @@ use axum::{
routing::post,
Router,
};
use http_body_util::BodyExt;
use tower::ServiceBuilder;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -51,9 +52,11 @@ async fn buffer_request_body(request: Request) -> Result<Request, Response> {
let (parts, body) = request.into_parts();
// this wont work if the body is an long running stream
let bytes = hyper::body::to_bytes(body)
let bytes = body
.collect()
.await
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?;
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?
.to_bytes();
do_thing_with_request_body(bytes.clone());

View file

@ -5,51 +5,56 @@
//! kill or ctrl-c
//! ```
use axum::{response::Html, routing::get, Router};
use std::net::SocketAddr;
use tokio::signal;
#[tokio::main]
async fn main() {
// build our application with a route
let app = Router::new().route("/", get(handler));
// run it
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);
hyper::Server::bind(&addr)
.serve(app.into_make_service())
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
async fn handler() -> Html<&'static str> {
Html("<h1>Hello, World!</h1>")
}
// use axum::{response::Html, routing::get, Router};
// use std::net::SocketAddr;
// use tokio::signal;
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
// #[tokio::main]
// async fn main() {
// // build our application with a route
// let app = Router::new().route("/", get(handler));
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
// // run it
// let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
// println!("listening on {}", addr);
// hyper::Server::bind(&addr)
// .serve(app.into_make_service())
// .with_graceful_shutdown(shutdown_signal())
// .await
// .unwrap();
// }
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
// async fn handler() -> Html<&'static str> {
// Html("<h1>Hello, World!</h1>")
// }
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
// async fn shutdown_signal() {
// let ctrl_c = async {
// signal::ctrl_c()
// .await
// .expect("failed to install Ctrl+C handler");
// };
println!("signal received, starting graceful shutdown");
}
// #[cfg(unix)]
// let terminate = async {
// signal::unix::signal(signal::unix::SignalKind::terminate())
// .expect("failed to install signal handler")
// .recv()
// .await;
// };
// #[cfg(not(unix))]
// let terminate = std::future::pending::<()>();
// tokio::select! {
// _ = ctrl_c => {},
// _ = terminate => {},
// }
// println!("signal received, starting graceful shutdown");
// }

View file

@ -9,5 +9,6 @@ axum = { path = "../../axum" }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]
http-body-util = "0.1.0-rc.2"
hyper = { version = "1.0.0-rc.3", features = ["full"] }
tower = { version = "0.4", features = ["util"] }

View file

@ -44,6 +44,7 @@ mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use http_body_util::BodyExt;
use tower::ServiceExt;
#[tokio::test]
@ -58,7 +59,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers()["x-some-header"], "header from GET");
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = response.collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"body from GET");
}
@ -74,7 +75,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers()["x-some-header"], "header from HEAD");
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = response.collect().await.unwrap().to_bytes();
assert!(body.is_empty());
}
}

View file

@ -12,91 +12,96 @@
//!
//! Example is based on <https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs>
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
response::{IntoResponse, Response},
routing::get,
Router,
};
use hyper::upgrade::Upgraded;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tower::{make::Shared, ServiceExt};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let router_svc = Router::new().route("/", get(|| async { "Hello, World!" }));
let service = tower::service_fn(move |req: Request<_>| {
let router_svc = router_svc.clone();
let req = req.map(Body::new);
async move {
if req.method() == Method::CONNECT {
proxy(req).await
} else {
router_svc.oneshot(req).await.map_err(|err| match err {})
}
}
});
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
hyper::Server::bind(&addr)
.http1_preserve_header_case(true)
.http1_title_case_headers(true)
.serve(Shared::new(service))
.await
.unwrap();
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
async fn proxy(req: Request) -> Result<Response, hyper::Error> {
tracing::trace!(?req);
// use axum::{
// body::Body,
// extract::Request,
// http::{Method, StatusCode},
// response::{IntoResponse, Response},
// routing::get,
// Router,
// };
// use hyper::upgrade::Upgraded;
// use std::net::SocketAddr;
// use tokio::net::TcpStream;
// use tower::{make::Shared, ServiceExt};
// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) {
tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {
Ok(upgraded) => {
if let Err(e) = tunnel(upgraded, host_addr).await {
tracing::warn!("server io error: {}", e);
};
}
Err(e) => tracing::warn!("upgrade error: {}", e),
}
});
// #[tokio::main]
// async fn main() {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
Ok(Response::new(Body::empty()))
} else {
tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri());
Ok((
StatusCode::BAD_REQUEST,
"CONNECT must be to a socket address",
)
.into_response())
}
}
// let router_svc = Router::new().route("/", get(|| async { "Hello, World!" }));
async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> {
let mut server = TcpStream::connect(addr).await?;
// let service = tower::service_fn(move |req: Request<_>| {
// let router_svc = router_svc.clone();
// let req = req.map(Body::new);
// async move {
// if req.method() == Method::CONNECT {
// proxy(req).await
// } else {
// router_svc.oneshot(req).await.map_err(|err| match err {})
// }
// }
// });
let (from_client, from_server) =
tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
// let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
// tracing::debug!("listening on {}", addr);
// hyper::Server::bind(&addr)
// .http1_preserve_header_case(true)
// .http1_title_case_headers(true)
// .serve(Shared::new(service))
// .await
// .unwrap();
// }
tracing::debug!(
"client wrote {} bytes and received {} bytes",
from_client,
from_server
);
// async fn proxy(req: Request) -> Result<Response, hyper::Error> {
// tracing::trace!(?req);
Ok(())
}
// if let Some(host_addr) = req.uri().authority().map(|auth| auth.to_string()) {
// tokio::task::spawn(async move {
// match hyper::upgrade::on(req).await {
// Ok(upgraded) => {
// if let Err(e) = tunnel(upgraded, host_addr).await {
// tracing::warn!("server io error: {}", e);
// };
// }
// Err(e) => tracing::warn!("upgrade error: {}", e),
// }
// });
// Ok(Response::new(Body::empty()))
// } else {
// tracing::warn!("CONNECT host is not socket addr: {:?}", req.uri());
// Ok((
// StatusCode::BAD_REQUEST,
// "CONNECT must be to a socket address",
// )
// .into_response())
// }
// }
// async fn tunnel(mut upgraded: Upgraded, addr: String) -> std::io::Result<()> {
// let mut server = TcpStream::connect(addr).await?;
// let (from_client, from_server) =
// tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
// tracing::debug!(
// "client wrote {} bytes and received {} bytes",
// from_client,
// from_server
// );
// Ok(())
// }

View file

@ -5,56 +5,61 @@
//! listen on both IPv4 and IPv6 when the IPv6 catch-all listener is used (`::`),
//! [like older versions of Windows.](https://docs.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets)
use axum::{routing::get, Router};
use hyper::server::{accept::Accept, conn::AddrIncoming};
use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
task::{Context, Poll},
};
#[tokio::main]
async fn main() {
let app = Router::new().route("/", get(|| async { "Hello, World!" }));
let localhost_v4 = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
let incoming_v4 = AddrIncoming::bind(&localhost_v4).unwrap();
let localhost_v6 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 8080);
let incoming_v6 = AddrIncoming::bind(&localhost_v6).unwrap();
let combined = CombinedIncoming {
a: incoming_v4,
b: incoming_v6,
};
hyper::Server::builder(combined)
.serve(app.into_make_service())
.await
.unwrap();
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
struct CombinedIncoming {
a: AddrIncoming,
b: AddrIncoming,
}
// use axum::{routing::get, Router};
// use hyper::server::{accept::Accept, conn::AddrIncoming};
// use std::{
// net::{Ipv4Addr, Ipv6Addr, SocketAddr},
// pin::Pin,
// task::{Context, Poll},
// };
impl Accept for CombinedIncoming {
type Conn = <AddrIncoming as Accept>::Conn;
type Error = <AddrIncoming as Accept>::Error;
// #[tokio::main]
// async fn main() {
// let app = Router::new().route("/", get(|| async { "Hello, World!" }));
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
if let Poll::Ready(Some(value)) = Pin::new(&mut self.a).poll_accept(cx) {
return Poll::Ready(Some(value));
}
// let localhost_v4 = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080);
// let incoming_v4 = AddrIncoming::bind(&localhost_v4).unwrap();
if let Poll::Ready(Some(value)) = Pin::new(&mut self.b).poll_accept(cx) {
return Poll::Ready(Some(value));
}
// let localhost_v6 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 8080);
// let incoming_v6 = AddrIncoming::bind(&localhost_v6).unwrap();
Poll::Pending
}
}
// let combined = CombinedIncoming {
// a: incoming_v4,
// b: incoming_v6,
// };
// hyper::Server::builder(combined)
// .serve(app.into_make_service())
// .await
// .unwrap();
// }
// struct CombinedIncoming {
// a: AddrIncoming,
// b: AddrIncoming,
// }
// impl Accept for CombinedIncoming {
// type Conn = <AddrIncoming as Accept>::Conn;
// type Error = <AddrIncoming as Accept>::Error;
// fn poll_accept(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
// if let Poll::Ready(Some(value)) = Pin::new(&mut self.a).poll_accept(cx) {
// return Poll::Ready(Some(value));
// }
// if let Poll::Ready(Some(value)) = Pin::new(&mut self.b).poll_accept(cx) {
// return Poll::Ready(Some(value));
// }
// Poll::Pending
// }
// }

View file

@ -1,88 +1,90 @@
use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod};
use tokio_openssl::SslStream;
use axum::{body::Body, extract::ConnectInfo, http::Request, routing::get, Router};
use futures_util::future::poll_fn;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, Http},
};
use std::{net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc};
use tokio::net::TcpListener;
use tower::MakeService;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_low_level_openssl=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap();
tls_builder
.set_certificate_file(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
SslFiletype::PEM,
)
.unwrap();
tls_builder
.set_private_key_file(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
SslFiletype::PEM,
)
.unwrap();
tls_builder.check_private_key().unwrap();
let acceptor = tls_builder.build();
let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
let mut listener = AddrIncoming::from_listener(listener).unwrap();
let protocol = Arc::new(Http::new());
let mut app = Router::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<SocketAddr>();
tracing::info!("listening on https://localhost:3000");
loop {
let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx))
.await
.unwrap()
.unwrap();
let acceptor = acceptor.clone();
let protocol = protocol.clone();
let svc = MakeService::<_, Request<Body>>::make_service(&mut app, &stream);
tokio::spawn(async move {
let ssl = Ssl::new(acceptor.context()).unwrap();
let mut tls_stream = SslStream::new(ssl, stream).unwrap();
SslStream::accept(Pin::new(&mut tls_stream)).await.unwrap();
let _ = protocol
.serve_connection(tls_stream, svc.await.unwrap())
.await;
});
}
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
addr.to_string()
}
// use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod};
// use tokio_openssl::SslStream;
// use axum::{body::Body, extract::ConnectInfo, http::Request, routing::get, Router};
// use futures_util::future::poll_fn;
// use std::{net::SocketAddr, path::PathBuf, pin::Pin, sync::Arc};
// use tokio::net::TcpListener;
// use tower::MakeService;
// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// #[tokio::main]
// async fn main() {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_low_level_openssl=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
// let mut tls_builder = SslAcceptor::mozilla_modern_v5(SslMethod::tls()).unwrap();
// tls_builder
// .set_certificate_file(
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("cert.pem"),
// SslFiletype::PEM,
// )
// .unwrap();
// tls_builder
// .set_private_key_file(
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("key.pem"),
// SslFiletype::PEM,
// )
// .unwrap();
// tls_builder.check_private_key().unwrap();
// let acceptor = tls_builder.build();
// let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
// let mut listener = AddrIncoming::from_listener(listener).unwrap();
// let protocol = Arc::new(Http::new());
// let mut app = Router::new()
// .route("/", get(handler))
// .into_make_service_with_connect_info::<SocketAddr>();
// tracing::info!("listening on https://localhost:3000");
// loop {
// let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx))
// .await
// .unwrap()
// .unwrap();
// let acceptor = acceptor.clone();
// let protocol = protocol.clone();
// let svc = MakeService::<_, Request<Body>>::make_service(&mut app, &stream);
// tokio::spawn(async move {
// let ssl = Ssl::new(acceptor.context()).unwrap();
// let mut tls_stream = SslStream::new(ssl, stream).unwrap();
// SslStream::accept(Pin::new(&mut tls_stream)).await.unwrap();
// let _ = protocol
// .serve_connection(tls_stream, svc.await.unwrap())
// .await;
// });
// }
// }
// #[allow(dead_code)]
// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
// addr.to_string()
// }

View file

@ -4,101 +4,104 @@
//! cargo run -p example-low-level-rustls
//! ```
use axum::{extract::ConnectInfo, extract::Request, routing::get, Router};
use futures_util::future::poll_fn;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, Http},
};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::{
fs::File,
io::BufReader,
net::SocketAddr,
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
};
use tokio::net::TcpListener;
use tokio_rustls::{
rustls::{Certificate, PrivateKey, ServerConfig},
TlsAcceptor,
};
use tower::make::MakeService;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_tls_rustls=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let rustls_config = rustls_server_config(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
);
let acceptor = TlsAcceptor::from(rustls_config);
let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
let mut listener = AddrIncoming::from_listener(listener).unwrap();
let protocol = Arc::new(Http::new());
let mut app = Router::<()>::new()
.route("/", get(handler))
.into_make_service_with_connect_info::<SocketAddr>();
loop {
let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx))
.await
.unwrap()
.unwrap();
let acceptor = acceptor.clone();
let protocol = protocol.clone();
let svc = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
tokio::spawn(async move {
if let Ok(stream) = acceptor.accept(stream).await {
let _ = protocol.serve_connection(stream, svc.await.unwrap()).await;
}
});
}
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
addr.to_string()
}
// use axum::{extract::ConnectInfo, extract::Request, routing::get, Router};
// use futures_util::future::poll_fn;
// use rustls_pemfile::{certs, pkcs8_private_keys};
// use std::{
// fs::File,
// io::BufReader,
// net::SocketAddr,
// path::{Path, PathBuf},
// pin::Pin,
// sync::Arc,
// };
// use tokio::net::TcpListener;
// use tokio_rustls::{
// rustls::{Certificate, PrivateKey, ServerConfig},
// TlsAcceptor,
// };
// use tower::make::MakeService;
// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
let mut key_reader = BufReader::new(File::open(key).unwrap());
let mut cert_reader = BufReader::new(File::open(cert).unwrap());
// #[tokio::main]
// async fn main() {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_tls_rustls=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0));
let certs = certs(&mut cert_reader)
.unwrap()
.into_iter()
.map(Certificate)
.collect();
// let rustls_config = rustls_server_config(
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("key.pem"),
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("cert.pem"),
// );
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, key)
.expect("bad certificate/key");
// let acceptor = TlsAcceptor::from(rustls_config);
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
// let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
// let mut listener = AddrIncoming::from_listener(listener).unwrap();
Arc::new(config)
}
// let protocol = Arc::new(Http::new());
// let mut app = Router::<()>::new()
// .route("/", get(handler))
// .into_make_service_with_connect_info::<SocketAddr>();
// loop {
// let stream = poll_fn(|cx| Pin::new(&mut listener).poll_accept(cx))
// .await
// .unwrap()
// .unwrap();
// let acceptor = acceptor.clone();
// let protocol = protocol.clone();
// let svc = MakeService::<_, Request<hyper::Body>>::make_service(&mut app, &stream);
// tokio::spawn(async move {
// if let Ok(stream) = acceptor.accept(stream).await {
// let _ = protocol.serve_connection(stream, svc.await.unwrap()).await;
// }
// });
// }
// }
// #[allow(dead_code)]
// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
// addr.to_string()
// }
// #[allow(dead_code)]
// fn rustls_server_config(key: impl AsRef<Path>, cert: impl AsRef<Path>) -> Arc<ServerConfig> {
// let mut key_reader = BufReader::new(File::open(key).unwrap());
// let mut cert_reader = BufReader::new(File::open(cert).unwrap());
// let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0));
// let certs = certs(&mut cert_reader)
// .unwrap()
// .into_iter()
// .map(Certificate)
// .collect();
// let mut config = ServerConfig::builder()
// .with_safe_defaults()
// .with_no_client_auth()
// .with_single_cert(certs, key)
// .expect("bad certificate/key");
// config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
// Arc::new(config)
// }

View file

@ -6,6 +6,7 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0-rc.2"
hyper = { version = "1.0.0-rc.3", features = ["full"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }

View file

@ -13,6 +13,7 @@ use axum::{
routing::post,
Router,
};
use http_body_util::BodyExt;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
@ -58,8 +59,8 @@ where
B: axum::body::HttpBody<Data = Bytes>,
B::Error: std::fmt::Display,
{
let bytes = match hyper::body::to_bytes(body).await {
Ok(bytes) => bytes,
let bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(err) => {
return Err((
StatusCode::BAD_REQUEST,

View file

@ -6,6 +6,7 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0-rc.2"
hyper = "1.0.0-rc.3"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.0", features = ["full"] }

View file

@ -58,6 +58,7 @@ where
mod tests {
use super::*;
use axum::{body::Body, http::Request};
use http_body_util::BodyExt;
use tower::ServiceExt;
#[tokio::test]
@ -114,7 +115,7 @@ mod tests {
.await
.unwrap()
.into_body();
let bytes = hyper::body::to_bytes(body).await.unwrap();
let bytes = body.collect().await.unwrap().to_bytes();
String::from_utf8(bytes.to_vec()).unwrap()
}
}

View file

@ -4,69 +4,74 @@
//! cargo run -p example-rest-grpc-multiplex
//! ```
use self::multiplex_service::MultiplexService;
use axum::{routing::get, Router};
use proto::{
greeter_server::{Greeter, GreeterServer},
HelloReply, HelloRequest,
};
use std::net::SocketAddr;
use tonic::{Response as TonicResponse, Status};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod multiplex_service;
mod proto {
tonic::include_proto!("helloworld");
// TODO: updating this example requires updating tonic
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
#[derive(Default)]
struct GrpcServiceImpl {}
// use self::multiplex_service::MultiplexService;
// use axum::{routing::get, Router};
// use proto::{
// greeter_server::{Greeter, GreeterServer},
// HelloReply, HelloRequest,
// };
// use std::net::SocketAddr;
// use tonic::{Response as TonicResponse, Status};
// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tonic::async_trait]
impl Greeter for GrpcServiceImpl {
async fn say_hello(
&self,
request: tonic::Request<HelloRequest>,
) -> Result<TonicResponse<HelloReply>, Status> {
tracing::info!("Got a request from {:?}", request.remote_addr());
// mod multiplex_service;
let reply = HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
// mod proto {
// tonic::include_proto!("helloworld");
// }
Ok(TonicResponse::new(reply))
}
}
// #[derive(Default)]
// struct GrpcServiceImpl {}
async fn web_root() -> &'static str {
"Hello, World!"
}
// #[tonic::async_trait]
// impl Greeter for GrpcServiceImpl {
// async fn say_hello(
// &self,
// request: tonic::Request<HelloRequest>,
// ) -> Result<TonicResponse<HelloReply>, Status> {
// tracing::info!("Got a request from {:?}", request.remote_addr());
#[tokio::main]
async fn main() {
// initialize tracing
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_rest_grpc_multiplex=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// let reply = HelloReply {
// message: format!("Hello {}!", request.into_inner().name),
// };
// build the rest service
let rest = Router::new().route("/", get(web_root));
// Ok(TonicResponse::new(reply))
// }
// }
// build the grpc service
let grpc = GreeterServer::new(GrpcServiceImpl::default());
// async fn web_root() -> &'static str {
// "Hello, World!"
// }
// combine them into one service
let service = MultiplexService::new(rest, grpc);
// #[tokio::main]
// async fn main() {
// // initialize tracing
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_rest_grpc_multiplex=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
hyper::Server::bind(&addr)
.serve(tower::make::Shared::new(service))
.await
.unwrap();
}
// // build the rest service
// let rest = Router::new().route("/", get(web_root));
// // build the grpc service
// let grpc = GreeterServer::new(GrpcServiceImpl::default());
// // combine them into one service
// let service = MultiplexService::new(rest, grpc);
// let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
// tracing::debug!("listening on {}", addr);
// hyper::Server::bind(&addr)
// .serve(tower::make::Shared::new(service))
// .await
// .unwrap();
// }

View file

@ -7,54 +7,59 @@
//! cargo run -p example-reverse-proxy
//! ```
use axum::{
body::Body,
extract::{Request, State},
http::uri::Uri,
response::{IntoResponse, Response},
routing::get,
Router,
};
use hyper::client::HttpConnector;
type Client = hyper::client::Client<HttpConnector, Body>;
#[tokio::main]
async fn main() {
tokio::spawn(server());
let client: Client = hyper::Client::builder().build(HttpConnector::new());
let app = Router::new().route("/", get(handler)).with_state(client);
let listener = tokio::net::TcpListener::bind("127.0.0.1:4000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
// TODO
fn main() {
eprint!("this example has not yet been updated to hyper 1.0");
}
async fn handler(State(client): State<Client>, mut req: Request) -> Response {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
// use axum::{
// body::Body,
// extract::{Request, State},
// http::uri::Uri,
// response::{IntoResponse, Response},
// routing::get,
// Router,
// };
// // use hyper::client::HttpConnector;
let uri = format!("http://127.0.0.1:3000{}", path_query);
// // type Client = hyper::client::Client<HttpConnector, Body>;
*req.uri_mut() = Uri::try_from(uri).unwrap();
// #[tokio::main]
// async fn main() {
// tokio::spawn(server());
client.request(req).await.unwrap().into_response()
}
// let client: Client = hyper::Client::builder().build(HttpConnector::new());
async fn server() {
let app = Router::new().route("/", get(|| async { "Hello, world!" }));
// let app = Router::new().route("/", get(handler)).with_state(client);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
// let listener = tokio::net::TcpListener::bind("127.0.0.1:4000")
// .await
// .unwrap();
// println!("listening on {}", listener.local_addr().unwrap());
// axum::serve(listener, app).await.unwrap();
// }
// async fn handler(State(client): State<Client>, mut req: Request) -> Response {
// let path = req.uri().path();
// let path_query = req
// .uri()
// .path_and_query()
// .map(|v| v.as_str())
// .unwrap_or(path);
// let uri = format!("http://127.0.0.1:3000{}", path_query);
// *req.uri_mut() = Uri::try_from(uri).unwrap();
// client.request(req).await.unwrap().into_response()
// }
// async fn server() {
// let app = Router::new().route("/", get(|| async { "Hello, world!" }));
// let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
// .await
// .unwrap();
// println!("listening on {}", listener.local_addr().unwrap());
// axum::serve(listener, app).await.unwrap();
// }

View file

@ -6,6 +6,7 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
http-body-util = "0.1.0-rc.2"
hyper = { version = "1.0.0-rc.3", features = ["full"] }
mime = "0.3"
serde_json = "1.0"

View file

@ -59,11 +59,12 @@ mod tests {
extract::connect_info::MockConnectInfo,
http::{self, Request, StatusCode},
};
use http_body_util::BodyExt;
use serde_json::{json, Value};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tower::Service; // for `call`
use tower::ServiceExt; // for `oneshot` and `ready`
use tower::ServiceExt; // for `oneshot` and `ready` // for `collect`
#[tokio::test]
async fn hello_world() {
@ -78,7 +79,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"Hello, World!");
}
@ -102,7 +103,7 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
let body: Value = serde_json::from_slice(&body).unwrap();
assert_eq!(body, json!({ "data": [1, 2, 3, 4] }));
}
@ -122,34 +123,36 @@ mod tests {
.unwrap();
assert_eq!(response.status(), StatusCode::NOT_FOUND);
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert!(body.is_empty());
}
// You can also spawn a server and talk to it like any other HTTP server:
#[tokio::test]
async fn the_real_deal() {
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr = listener.local_addr().unwrap();
todo!();
tokio::spawn(async move {
axum::serve(listener, app()).await.unwrap();
});
// let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
// let addr = listener.local_addr().unwrap();
let client = hyper::Client::new();
// tokio::spawn(async move {
// axum::serve(listener, app()).await.unwrap();
// });
let response = client
.request(
Request::builder()
.uri(format!("http://{}", addr))
.body(hyper::Body::empty())
.unwrap(),
)
.await
.unwrap();
// let client = hyper::Client::new();
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&body[..], b"Hello, World!");
// let response = client
// .request(
// Request::builder()
// .uri(format!("http://{}", addr))
// .body(axum::Body::empty())
// .unwrap(),
// )
// .await
// .unwrap();
// let body = response.into_body().collect().await.unwrap().to_bytes();
// assert_eq!(&body[..], b"Hello, World!");
}
// You can use `ready()` and `call()` to avoid using `clone()`

View file

@ -4,6 +4,8 @@
//! cargo run -p example-tls-rustls
//! ```
#![allow(unused_imports)]
use axum::{
extract::Host,
handler::HandlerWithoutStateExt,
@ -16,6 +18,7 @@ use axum_server::tls_rustls::RustlsConfig;
use std::{net::SocketAddr, path::PathBuf};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct Ports {
http: u16,
@ -24,48 +27,52 @@ struct Ports {
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "example_tls_rustls=debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// Updating this example to hyper 1.0 requires axum_server to update first
let ports = Ports {
http: 7878,
https: 3000,
};
// optional: spawn a second server to redirect http requests to this server
tokio::spawn(redirect_http_to_https(ports));
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "example_tls_rustls=debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
// configure certificate and private key used by https
let config = RustlsConfig::from_pem_file(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("cert.pem"),
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("self_signed_certs")
.join("key.pem"),
)
.await
.unwrap();
// let ports = Ports {
// http: 7878,
// https: 3000,
// };
// // optional: spawn a second server to redirect http requests to this server
// tokio::spawn(redirect_http_to_https(ports));
let app = Router::new().route("/", get(handler));
// // configure certificate and private key used by https
// let config = RustlsConfig::from_pem_file(
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("cert.pem"),
// PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// .join("self_signed_certs")
// .join("key.pem"),
// )
// .await
// .unwrap();
// run https server
let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
tracing::debug!("listening on {}", addr);
axum_server::bind_rustls(addr, config)
.serve(app.into_make_service())
.await
.unwrap();
// let app = Router::new().route("/", get(handler));
// // run https server
// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https));
// tracing::debug!("listening on {}", addr);
// axum_server::bind_rustls(addr, config)
// .await
// .unwrap();
}
#[allow(dead_code)]
async fn handler() -> &'static str {
"Hello, World!"
}
#[allow(dead_code)]
async fn redirect_http_to_https(ports: Ports) {
fn make_https(host: String, uri: Uri, ports: Ports) -> Result<Uri, BoxError> {
let mut parts = uri.into_parts();

View file

@ -4,178 +4,183 @@
//! cargo run -p example-unix-domain-socket
//! ```
#[cfg(unix)]
#[tokio::main]
async fn main() {
unix::server().await;
}
#[cfg(not(unix))]
// TODO
fn main() {
println!("This example requires unix")
eprint!("this example has not yet been updated to hyper 1.0");
}
#[cfg(unix)]
mod unix {
use axum::{
body::Body,
extract::connect_info::{self, ConnectInfo},
http::{Method, Request, StatusCode, Uri},
routing::get,
Router,
};
use futures::ready;
use hyper::{
client::connect::{Connected, Connection},
server::accept::Accept,
};
use std::{
io,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{unix::UCred, UnixListener, UnixStream},
};
use tower::BoxError;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// #[cfg(unix)]
// #[tokio::main]
// async fn main() {
// unix::server().await;
// }
pub async fn server() {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "debug".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// #[cfg(not(unix))]
// fn main() {
// println!("This example requires unix")
// }
let path = PathBuf::from("/tmp/axum/helloworld");
// #[cfg(unix)]
// mod unix {
// use axum::{
// body::Body,
// extract::connect_info::{self, ConnectInfo},
// http::{Method, Request, StatusCode, Uri},
// routing::get,
// Router,
// };
// use futures::ready;
// use hyper::{
// client::connect::{Connected, Connection},
// server::accept::Accept,
// };
// use std::{
// io,
// path::PathBuf,
// pin::Pin,
// sync::Arc,
// task::{Context, Poll},
// };
// use tokio::{
// io::{AsyncRead, AsyncWrite},
// net::{unix::UCred, UnixListener, UnixStream},
// };
// use tower::BoxError;
// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
let _ = tokio::fs::remove_file(&path).await;
tokio::fs::create_dir_all(path.parent().unwrap())
.await
.unwrap();
// pub async fn server() {
// tracing_subscriber::registry()
// .with(
// tracing_subscriber::EnvFilter::try_from_default_env()
// .unwrap_or_else(|_| "debug".into()),
// )
// .with(tracing_subscriber::fmt::layer())
// .init();
let uds = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async {
let app = Router::new().route("/", get(handler));
// let path = PathBuf::from("/tmp/axum/helloworld");
hyper::Server::builder(ServerAccept { uds })
.serve(app.into_make_service_with_connect_info::<UdsConnectInfo>())
.await
.unwrap();
});
// let _ = tokio::fs::remove_file(&path).await;
// tokio::fs::create_dir_all(path.parent().unwrap())
// .await
// .unwrap();
let connector = tower::service_fn(move |_: Uri| {
let path = path.clone();
Box::pin(async move {
let stream = UnixStream::connect(path).await?;
Ok::<_, io::Error>(ClientConnection { stream })
})
});
let client = hyper::Client::builder().build(connector);
// let uds = UnixListener::bind(path.clone()).unwrap();
// tokio::spawn(async {
// let app = Router::new().route("/", get(handler));
let request = Request::builder()
.method(Method::GET)
.uri("http://uri-doesnt-matter.com")
.body(Body::empty())
.unwrap();
// hyper::Server::builder(ServerAccept { uds })
// .serve(app.into_make_service_with_connect_info::<UdsConnectInfo>())
// .await
// .unwrap();
// });
let response = client.request(request).await.unwrap();
// let connector = tower::service_fn(move |_: Uri| {
// let path = path.clone();
// Box::pin(async move {
// let stream = UnixStream::connect(path).await?;
// Ok::<_, io::Error>(ClientConnection { stream })
// })
// });
// let client = hyper::Client::builder().build(connector);
assert_eq!(response.status(), StatusCode::OK);
// let request = Request::builder()
// .method(Method::GET)
// .uri("http://uri-doesnt-matter.com")
// .body(Body::empty())
// .unwrap();
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
let body = String::from_utf8(body.to_vec()).unwrap();
assert_eq!(body, "Hello, World!");
}
// let response = client.request(request).await.unwrap();
async fn handler(ConnectInfo(info): ConnectInfo<UdsConnectInfo>) -> &'static str {
println!("new connection from `{:?}`", info);
// assert_eq!(response.status(), StatusCode::OK);
"Hello, World!"
}
// let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
// let body = String::from_utf8(body.to_vec()).unwrap();
// assert_eq!(body, "Hello, World!");
// }
struct ServerAccept {
uds: UnixListener,
}
// async fn handler(ConnectInfo(info): ConnectInfo<UdsConnectInfo>) -> &'static str {
// println!("new connection from `{:?}`", info);
impl Accept for ServerAccept {
type Conn = UnixStream;
type Error = BoxError;
// "Hello, World!"
// }
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let (stream, _addr) = ready!(self.uds.poll_accept(cx))?;
Poll::Ready(Some(Ok(stream)))
}
}
// struct ServerAccept {
// uds: UnixListener,
// }
struct ClientConnection {
stream: UnixStream,
}
// impl Accept for ServerAccept {
// type Conn = UnixStream;
// type Error = BoxError;
impl AsyncWrite for ClientConnection {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}
// fn poll_accept(
// self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
// let (stream, _addr) = ready!(self.uds.poll_accept(cx))?;
// Poll::Ready(Some(Ok(stream)))
// }
// }
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// struct ClientConnection {
// stream: UnixStream,
// }
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}
// impl AsyncWrite for ClientConnection {
// fn poll_write(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &[u8],
// ) -> Poll<Result<usize, io::Error>> {
// Pin::new(&mut self.stream).poll_write(cx, buf)
// }
impl AsyncRead for ClientConnection {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
// fn poll_flush(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// ) -> Poll<Result<(), io::Error>> {
// Pin::new(&mut self.stream).poll_flush(cx)
// }
impl Connection for ClientConnection {
fn connected(&self) -> Connected {
Connected::new()
}
}
// fn poll_shutdown(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// ) -> Poll<Result<(), io::Error>> {
// Pin::new(&mut self.stream).poll_shutdown(cx)
// }
// }
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct UdsConnectInfo {
peer_addr: Arc<tokio::net::unix::SocketAddr>,
peer_cred: UCred,
}
// impl AsyncRead for ClientConnection {
// fn poll_read(
// mut self: Pin<&mut Self>,
// cx: &mut Context<'_>,
// buf: &mut tokio::io::ReadBuf<'_>,
// ) -> Poll<io::Result<()>> {
// Pin::new(&mut self.stream).poll_read(cx, buf)
// }
// }
impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
fn connect_info(target: &UnixStream) -> Self {
let peer_addr = target.peer_addr().unwrap();
let peer_cred = target.peer_cred().unwrap();
// impl Connection for ClientConnection {
// fn connected(&self) -> Connected {
// Connected::new()
// }
// }
Self {
peer_addr: Arc::new(peer_addr),
peer_cred,
}
}
}
}
// #[derive(Clone, Debug)]
// #[allow(dead_code)]
// struct UdsConnectInfo {
// peer_addr: Arc<tokio::net::unix::SocketAddr>,
// peer_cred: UCred,
// }
// impl connect_info::Connected<&UnixStream> for UdsConnectInfo {
// fn connect_info(target: &UnixStream) -> Self {
// let peer_addr = target.peer_addr().unwrap();
// let peer_cred = target.peer_cred().unwrap();
// Self {
// peer_addr: Arc::new(peer_addr),
// peer_cred,
// }
// }
// }
// }