Use pin-project-lite instead of pin-project (#95)

This commit is contained in:
Sunli 2021-08-03 15:33:00 +08:00 committed by GitHub
parent ba74787532
commit be68227d73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 268 additions and 236 deletions

View file

@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
- Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95))
## Breaking changes

View file

@ -24,7 +24,7 @@ futures-util = "0.3"
http = "0.2"
http-body = "0.4.2"
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
pin-project = "1.0"
pin-project-lite = "0.2.7"
regex = "1.5"
serde = "1.0"
serde_json = "1.0"

View file

@ -1,5 +1,5 @@
use futures_util::ready;
use pin_project::pin_project;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
@ -119,23 +119,26 @@ where
});
ResponseFuture {
state: State::Channel(reply_rx),
state: State::Channel { reply_rx },
permit,
}
}
}
#[pin_project]
pub(crate) struct ResponseFuture<F, E> {
#[pin]
state: State<F, E>,
permit: OwnedSemaphorePermit,
pin_project! {
pub(crate) struct ResponseFuture<F, E> {
#[pin]
state: State<F, E>,
permit: OwnedSemaphorePermit,
}
}
#[pin_project(project = StateProj)]
enum State<F, E> {
Channel(oneshot::Receiver<WorkerReply<F, E>>),
Future(#[pin] F),
pin_project! {
#[project = StateProj]
enum State<F, E> {
Channel { reply_rx: oneshot::Receiver<WorkerReply<F, E>> },
Future { #[pin] future: F },
}
}
impl<F, E, T> Future for ResponseFuture<F, E>
@ -149,16 +152,16 @@ where
let mut this = self.as_mut().project();
let new_state = match this.state.as_mut().project() {
StateProj::Channel(reply_rx) => {
StateProj::Channel { reply_rx } => {
let msg = ready!(Pin::new(reply_rx).poll(cx))
.expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue");
match msg {
WorkerReply::Future(future) => State::Future(future),
WorkerReply::Future(future) => State::Future { future },
WorkerReply::Error(err) => return Poll::Ready(Err(err)),
}
}
StateProj::Future(future) => {
StateProj::Future { future } => {
return future.poll(cx);
}
};

View file

@ -89,7 +89,9 @@ where
fn call(&mut self, target: T) -> Self::Future {
let connect_info = ConnectInfo(C::connect_info(target));
let svc = AddExtension::new(self.svc.clone(), connect_info);
ResponseFuture(futures_util::future::ok(svc))
ResponseFuture {
future: futures_util::future::ok(svc),
}
}
}

View file

@ -7,7 +7,7 @@ use crate::{body::BoxBody, response::IntoResponse};
use bytes::Bytes;
use futures_util::{future::BoxFuture, ready};
use http::{Request, Response};
use pin_project::pin_project;
use pin_project_lite::pin_project;
use std::{
fmt,
future::Future,
@ -178,33 +178,38 @@ where
});
ExtractorMiddlewareResponseFuture {
state: State::Extracting(extract_future),
state: State::Extracting {
future: extract_future,
},
svc: Some(self.inner.clone()),
}
}
}
/// Response future for [`ExtractorMiddleware`].
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
#[pin]
state: State<ReqBody, S, E>,
svc: Option<S>,
pin_project! {
/// Response future for [`ExtractorMiddleware`].
#[allow(missing_debug_implementations)]
pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
#[pin]
state: State<ReqBody, S, E>,
svc: Option<S>,
}
}
#[pin_project(project = StateProj)]
enum State<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
Extracting(BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)>),
Call(#[pin] S::Future),
pin_project! {
#[project = StateProj]
enum State<ReqBody, S, E>
where
E: FromRequest<ReqBody>,
S: Service<Request<ReqBody>>,
{
Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> },
Call { #[pin] future: S::Future },
}
}
impl<ReqBody, S, E, ResBody> Future for ExtractorMiddlewareResponseFuture<ReqBody, S, E>
@ -221,14 +226,14 @@ where
let mut this = self.as_mut().project();
let new_state = match this.state.as_mut().project() {
StateProj::Extracting(future) => {
StateProj::Extracting { future } => {
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
match extracted {
Ok(_) => {
let mut svc = this.svc.take().expect("future polled after completion");
let future = svc.call(req.into_request());
State::Call(future)
State::Call { future }
}
Err(err) => {
let res = err.into_response().map(crate::body::box_body);
@ -236,7 +241,7 @@ where
}
}
}
StateProj::Call(future) => {
StateProj::Call { future } => {
return future
.poll(cx)
.map(|result| result.map(|response| response.map(crate::body::box_body)));

View file

@ -444,7 +444,7 @@ where
let res = Handler::call(handler, req).await;
Ok(res)
});
future::IntoServiceFuture(future)
future::IntoServiceFuture { future }
}
}

View file

@ -9,10 +9,12 @@ macro_rules! opaque_future {
};
($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => {
#[pin_project::pin_project]
$(#[$m])*
pub struct $name<$($param),*>(#[pin] pub(crate) $actual)
where;
pin_project_lite::pin_project! {
$(#[$m])*
pub struct $name<$($param),*> {
#[pin] pub(crate) future: $actual,
}
}
impl<$($param),*> std::fmt::Debug for $name<$($param),*> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@ -27,7 +29,7 @@ macro_rules! opaque_future {
type Output = <$actual as std::future::Future>::Output;
#[inline]
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
self.project().0.poll(cx)
self.project().future.poll(cx)
}
}
};

View file

@ -11,7 +11,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use futures_util::future;
use http::{Method, Request, Response, StatusCode, Uri};
use pin_project::pin_project;
use pin_project_lite::pin_project;
use regex::Regex;
use std::{
borrow::Cow,
@ -385,13 +385,16 @@ where
}
}
/// The response future for [`Route`].
#[pin_project]
#[derive(Debug)]
pub struct RouteFuture<S, F, B>(#[pin] RouteFutureInner<S, F, B>)
where
S: Service<Request<B>>,
F: Service<Request<B>>;
pin_project! {
/// The response future for [`Route`].
#[derive(Debug)]
pub struct RouteFuture<S, F, B>
where
S: Service<Request<B>>,
F: Service<Request<B>> {
#[pin] inner: RouteFutureInner<S, F, B>,
}
}
impl<S, F, B> RouteFuture<S, F, B>
where
@ -399,23 +402,29 @@ where
F: Service<Request<B>>,
{
pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
RouteFuture(RouteFutureInner::A(a))
RouteFuture {
inner: RouteFutureInner::A { a },
}
}
pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self {
RouteFuture(RouteFutureInner::B(b))
RouteFuture {
inner: RouteFutureInner::B { b },
}
}
}
#[pin_project(project = RouteFutureInnerProj)]
#[derive(Debug)]
enum RouteFutureInner<S, F, B>
where
S: Service<Request<B>>,
F: Service<Request<B>>,
{
A(#[pin] Oneshot<S, Request<B>>),
B(#[pin] Oneshot<F, Request<B>>),
pin_project! {
#[project = RouteFutureInnerProj]
#[derive(Debug)]
enum RouteFutureInner<S, F, B>
where
S: Service<Request<B>>,
F: Service<Request<B>>,
{
A { #[pin] a: Oneshot<S, Request<B>> },
B { #[pin] b: Oneshot<F, Request<B>> },
}
}
impl<S, F, B> Future for RouteFuture<S, F, B>
@ -426,9 +435,9 @@ where
type Output = Result<Response<BoxBody>, S::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().0.project() {
RouteFutureInnerProj::A(inner) => inner.poll(cx),
RouteFutureInnerProj::B(inner) => inner.poll(cx),
match self.project().inner.project() {
RouteFutureInnerProj::A { a } => a.poll(cx),
RouteFutureInnerProj::B { b } => b.poll(cx),
}
}
}
@ -510,7 +519,9 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> {
fn call(&mut self, _req: Request<B>) -> Self::Future {
let mut res = Response::new(crate::body::empty());
*res.status_mut() = self.status;
EmptyRouterFuture(future::ok(res))
EmptyRouterFuture {
future: future::ok(res),
}
}
}
@ -655,15 +666,14 @@ where
}
}
/// The response future for [`BoxRoute`].
#[pin_project]
pub struct BoxRouteFuture<B, E>
where
E: Into<BoxError>,
{
#[pin]
inner:
Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
pin_project! {
/// The response future for [`BoxRoute`].
pub struct BoxRouteFuture<B, E>
where
E: Into<BoxError>,
{
#[pin] inner: Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
}
}
impl<B, E> Future for BoxRouteFuture<B, E>

View file

@ -7,7 +7,7 @@ use crate::{
use bytes::Bytes;
use futures_util::ready;
use http::Response;
use pin_project::pin_project;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
@ -15,13 +15,14 @@ use std::{
};
use tower::BoxError;
/// Response future for [`HandleError`](super::HandleError).
#[pin_project]
#[derive(Debug)]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
pub(super) inner: Fut,
pub(super) f: Option<F>,
pin_project! {
/// Response future for [`HandleError`](super::HandleError).
#[derive(Debug)]
pub struct HandleErrorFuture<Fut, F> {
#[pin]
pub(super) inner: Fut,
pub(super) f: Option<F>,
}
}
impl<Fut, F, E, E2, B, Res> Future for HandleErrorFuture<Fut, F>

View file

@ -94,7 +94,7 @@ use crate::{
use bytes::Bytes;
use futures_util::ready;
use http::{Request, Response};
use pin_project::pin_project;
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
fmt,
@ -645,14 +645,17 @@ where
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let fut = self.inner.clone().oneshot(req);
BoxResponseBodyFuture(fut)
BoxResponseBodyFuture { future: fut }
}
}
/// Response future for [`BoxResponseBody`].
#[pin_project]
#[derive(Debug)]
pub struct BoxResponseBodyFuture<F>(#[pin] F);
pin_project! {
/// Response future for [`BoxResponseBody`].
#[derive(Debug)]
pub struct BoxResponseBodyFuture<F> {
#[pin] future: F,
}
}
impl<F, B, E> Future for BoxResponseBodyFuture<F>
where
@ -663,7 +666,7 @@ where
type Output = Result<Response<BoxBody>, E>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = ready!(self.project().0.poll(cx))?;
let res = ready!(self.project().future.poll(cx))?;
let res = res.map(box_body);
Poll::Ready(Ok(res))
}

View file

@ -81,7 +81,7 @@ use futures_util::{
};
use http::{Request, Response};
use hyper::Body;
use pin_project::pin_project;
use pin_project_lite::pin_project;
use serde::Serialize;
use std::{
borrow::Cow,
@ -245,49 +245,51 @@ where
let handler = self.handler.clone();
let keep_alive = self.keep_alive.clone();
ResponseFuture(Box::pin(async move {
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
ResponseFuture {
future: Box::pin(async move {
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
let stream = match handler.call(input).await {
Ok(stream) => stream,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
let stream = match handler.call(input).await {
Ok(stream) => stream,
Err(err) => {
return Ok(err.into_response().map(box_body));
}
};
let stream = if let Some(keep_alive) = keep_alive {
KeepAliveStream {
event_stream: stream,
comment_text: keep_alive.comment_text,
max_interval: keep_alive.max_interval,
alive_timer: tokio::time::sleep(keep_alive.max_interval),
}
.left_stream()
} else {
stream.into_stream().right_stream()
};
let stream = if let Some(keep_alive) = keep_alive {
KeepAliveStream {
event_stream: stream,
comment_text: keep_alive.comment_text,
max_interval: keep_alive.max_interval,
alive_timer: tokio::time::sleep(keep_alive.max_interval),
}
.left_stream()
} else {
stream.into_stream().right_stream()
};
let stream = stream
.map_ok(|event| event.to_string())
.map_err(|err| BoxStdError(err.into()))
.into_stream();
let stream = stream
.map_ok(|event| event.to_string())
.map_err(|err| BoxStdError(err.into()))
.into_stream();
let body = box_body(Body::wrap_stream(stream));
let body = box_body(Body::wrap_stream(stream));
let response = Response::builder()
.header(http::header::CONTENT_TYPE, "text/event-stream")
.header(http::header::CACHE_CONTROL, "no-cache")
.body(body)
.unwrap();
let response = Response::builder()
.header(http::header::CONTENT_TYPE, "text/event-stream")
.header(http::header::CACHE_CONTROL, "no-cache")
.body(body)
.unwrap();
Ok(response)
}))
Ok(response)
}),
}
}
}
@ -483,14 +485,15 @@ impl Default for KeepAlive {
}
}
#[pin_project]
struct KeepAliveStream<S> {
#[pin]
event_stream: S,
comment_text: Cow<'static, str>,
max_interval: Duration,
#[pin]
alive_timer: Sleep,
pin_project! {
struct KeepAliveStream<S> {
#[pin]
event_stream: S,
comment_text: Cow<'static, str>,
max_interval: Duration,
#[pin]
alive_timer: Sleep,
}
}
impl<S> Stream for KeepAliveStream<S>

View file

@ -277,110 +277,112 @@ where
let this = self.clone();
let protocols = self.protocols.clone();
ResponseFuture(Box::pin(async move {
if req.method() != http::Method::GET {
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
}
ResponseFuture {
future: Box::pin(async move {
if req.method() != http::Method::GET {
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
}
if !header_contains(&req, header::CONNECTION, "upgrade") {
return response(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
if !header_contains(&req, header::CONNECTION, "upgrade") {
return response(
StatusCode::BAD_REQUEST,
"Connection header did not include 'upgrade'",
);
}
if !header_eq(&req, header::UPGRADE, "websocket") {
return response(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
if !header_eq(&req, header::UPGRADE, "websocket") {
return response(
StatusCode::BAD_REQUEST,
"`Upgrade` header did not include 'websocket'",
);
}
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Version` header did not include '13'",
);
}
// check requested protocols
let protocol =
req.headers()
.get(&header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Protocol` header is invalid",
);
// check requested protocols
let protocol =
req.headers()
.get(&header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let protocol = match protocol {
Some(protocol) => {
if let Ok(protocol) = HeaderValue::from_str(protocol) {
Some(protocol)
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Protocol` header is invalid",
);
}
}
None => None,
};
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
let config = this.config;
let callback = this.callback.clone();
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(rejection) => {
let res = rejection.into_response().map(box_body);
return Ok(res);
}
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket = WebSocketStream::from_raw_socket(
upgraded,
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket { inner: socket };
callback.call(socket, input).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
if let Some(protocol) = protocol {
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
None => None,
};
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
key
} else {
return response(
StatusCode::BAD_REQUEST,
"`Sec-Websocket-Key` header missing",
);
};
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
let config = this.config;
let callback = this.callback.clone();
let mut req = RequestParts::new(req);
let input = match T::from_request(&mut req).await {
Ok(input) => input,
Err(rejection) => {
let res = rejection.into_response().map(box_body);
return Ok(res);
}
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.unwrap();
let socket = WebSocketStream::from_raw_socket(
upgraded,
protocol::Role::Server,
Some(config),
)
.await;
let socket = WebSocket { inner: socket };
callback.call(socket, input).await;
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(
http::header::CONNECTION,
HeaderValue::from_str("upgrade").unwrap(),
)
.header(
http::header::UPGRADE,
HeaderValue::from_str("websocket").unwrap(),
)
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
if let Some(protocol) = protocol {
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
Ok(res)
}))
Ok(res)
}),
}
}
}