From be68227d739aece4210d4e7cb7d419da4d4afe30 Mon Sep 17 00:00:00 2001 From: Sunli <sunlipad4@icloud.com> Date: Tue, 3 Aug 2021 15:33:00 +0800 Subject: [PATCH] Use `pin-project-lite` instead of `pin-project` (#95) --- CHANGELOG.md | 1 + Cargo.toml | 2 +- src/buffer.rs | 31 +++-- src/extract/connect_info.rs | 4 +- src/extract/extractor_middleware.rs | 53 ++++---- src/handler/mod.rs | 2 +- src/macros.rs | 12 +- src/routing.rs | 74 ++++++----- src/service/future.rs | 17 +-- src/service/mod.rs | 17 ++- src/sse.rs | 95 +++++++------- src/ws/mod.rs | 196 ++++++++++++++-------------- 12 files changed, 268 insertions(+), 236 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f07c3e76..484f7e08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91)) - Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83)) +- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95)) ## Breaking changes diff --git a/Cargo.toml b/Cargo.toml index bcf17d4f..4234e409 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ futures-util = "0.3" http = "0.2" http-body = "0.4.2" hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] } -pin-project = "1.0" +pin-project-lite = "0.2.7" regex = "1.5" serde = "1.0" serde_json = "1.0" diff --git a/src/buffer.rs b/src/buffer.rs index c7134adc..4b85927c 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,5 +1,5 @@ use futures_util::ready; -use pin_project::pin_project; +use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, @@ -119,23 +119,26 @@ where }); ResponseFuture { - state: State::Channel(reply_rx), + state: State::Channel { reply_rx }, permit, } } } -#[pin_project] -pub(crate) struct ResponseFuture<F, E> { - #[pin] - state: State<F, E>, - permit: OwnedSemaphorePermit, +pin_project! { + pub(crate) struct ResponseFuture<F, E> { + #[pin] + state: State<F, E>, + permit: OwnedSemaphorePermit, + } } -#[pin_project(project = StateProj)] -enum State<F, E> { - Channel(oneshot::Receiver<WorkerReply<F, E>>), - Future(#[pin] F), +pin_project! { + #[project = StateProj] + enum State<F, E> { + Channel { reply_rx: oneshot::Receiver<WorkerReply<F, E>> }, + Future { #[pin] future: F }, + } } impl<F, E, T> Future for ResponseFuture<F, E> @@ -149,16 +152,16 @@ where let mut this = self.as_mut().project(); let new_state = match this.state.as_mut().project() { - StateProj::Channel(reply_rx) => { + StateProj::Channel { reply_rx } => { let msg = ready!(Pin::new(reply_rx).poll(cx)) .expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue"); match msg { - WorkerReply::Future(future) => State::Future(future), + WorkerReply::Future(future) => State::Future { future }, WorkerReply::Error(err) => return Poll::Ready(Err(err)), } } - StateProj::Future(future) => { + StateProj::Future { future } => { return future.poll(cx); } }; diff --git a/src/extract/connect_info.rs b/src/extract/connect_info.rs index cb92e095..0a0655c5 100644 --- a/src/extract/connect_info.rs +++ b/src/extract/connect_info.rs @@ -89,7 +89,9 @@ where fn call(&mut self, target: T) -> Self::Future { let connect_info = ConnectInfo(C::connect_info(target)); let svc = AddExtension::new(self.svc.clone(), connect_info); - ResponseFuture(futures_util::future::ok(svc)) + ResponseFuture { + future: futures_util::future::ok(svc), + } } } diff --git a/src/extract/extractor_middleware.rs b/src/extract/extractor_middleware.rs index a8fa6301..84372273 100644 --- a/src/extract/extractor_middleware.rs +++ b/src/extract/extractor_middleware.rs @@ -7,7 +7,7 @@ use crate::{body::BoxBody, response::IntoResponse}; use bytes::Bytes; use futures_util::{future::BoxFuture, ready}; use http::{Request, Response}; -use pin_project::pin_project; +use pin_project_lite::pin_project; use std::{ fmt, future::Future, @@ -178,33 +178,38 @@ where }); ExtractorMiddlewareResponseFuture { - state: State::Extracting(extract_future), + state: State::Extracting { + future: extract_future, + }, svc: Some(self.inner.clone()), } } } -/// Response future for [`ExtractorMiddleware`]. -#[allow(missing_debug_implementations)] -#[pin_project] -pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E> -where - E: FromRequest<ReqBody>, - S: Service<Request<ReqBody>>, -{ - #[pin] - state: State<ReqBody, S, E>, - svc: Option<S>, +pin_project! { + /// Response future for [`ExtractorMiddleware`]. + #[allow(missing_debug_implementations)] + pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E> + where + E: FromRequest<ReqBody>, + S: Service<Request<ReqBody>>, + { + #[pin] + state: State<ReqBody, S, E>, + svc: Option<S>, + } } -#[pin_project(project = StateProj)] -enum State<ReqBody, S, E> -where - E: FromRequest<ReqBody>, - S: Service<Request<ReqBody>>, -{ - Extracting(BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)>), - Call(#[pin] S::Future), +pin_project! { + #[project = StateProj] + enum State<ReqBody, S, E> + where + E: FromRequest<ReqBody>, + S: Service<Request<ReqBody>>, + { + Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> }, + Call { #[pin] future: S::Future }, + } } impl<ReqBody, S, E, ResBody> Future for ExtractorMiddlewareResponseFuture<ReqBody, S, E> @@ -221,14 +226,14 @@ where let mut this = self.as_mut().project(); let new_state = match this.state.as_mut().project() { - StateProj::Extracting(future) => { + StateProj::Extracting { future } => { let (mut req, extracted) = ready!(future.as_mut().poll(cx)); match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); let future = svc.call(req.into_request()); - State::Call(future) + State::Call { future } } Err(err) => { let res = err.into_response().map(crate::body::box_body); @@ -236,7 +241,7 @@ where } } } - StateProj::Call(future) => { + StateProj::Call { future } => { return future .poll(cx) .map(|result| result.map(|response| response.map(crate::body::box_body))); diff --git a/src/handler/mod.rs b/src/handler/mod.rs index aaa9792a..ed065b4f 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -444,7 +444,7 @@ where let res = Handler::call(handler, req).await; Ok(res) }); - future::IntoServiceFuture(future) + future::IntoServiceFuture { future } } } diff --git a/src/macros.rs b/src/macros.rs index 15b8597f..238b21d0 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -9,10 +9,12 @@ macro_rules! opaque_future { }; ($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => { - #[pin_project::pin_project] - $(#[$m])* - pub struct $name<$($param),*>(#[pin] pub(crate) $actual) - where; + pin_project_lite::pin_project! { + $(#[$m])* + pub struct $name<$($param),*> { + #[pin] pub(crate) future: $actual, + } + } impl<$($param),*> std::fmt::Debug for $name<$($param),*> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -27,7 +29,7 @@ macro_rules! opaque_future { type Output = <$actual as std::future::Future>::Output; #[inline] fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> { - self.project().0.poll(cx) + self.project().future.poll(cx) } } }; diff --git a/src/routing.rs b/src/routing.rs index 21f79d41..4a74763b 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -11,7 +11,7 @@ use async_trait::async_trait; use bytes::Bytes; use futures_util::future; use http::{Method, Request, Response, StatusCode, Uri}; -use pin_project::pin_project; +use pin_project_lite::pin_project; use regex::Regex; use std::{ borrow::Cow, @@ -385,13 +385,16 @@ where } } -/// The response future for [`Route`]. -#[pin_project] -#[derive(Debug)] -pub struct RouteFuture<S, F, B>(#[pin] RouteFutureInner<S, F, B>) -where - S: Service<Request<B>>, - F: Service<Request<B>>; +pin_project! { + /// The response future for [`Route`]. + #[derive(Debug)] + pub struct RouteFuture<S, F, B> + where + S: Service<Request<B>>, + F: Service<Request<B>> { + #[pin] inner: RouteFutureInner<S, F, B>, + } +} impl<S, F, B> RouteFuture<S, F, B> where @@ -399,23 +402,29 @@ where F: Service<Request<B>>, { pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self { - RouteFuture(RouteFutureInner::A(a)) + RouteFuture { + inner: RouteFutureInner::A { a }, + } } pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self { - RouteFuture(RouteFutureInner::B(b)) + RouteFuture { + inner: RouteFutureInner::B { b }, + } } } -#[pin_project(project = RouteFutureInnerProj)] -#[derive(Debug)] -enum RouteFutureInner<S, F, B> -where - S: Service<Request<B>>, - F: Service<Request<B>>, -{ - A(#[pin] Oneshot<S, Request<B>>), - B(#[pin] Oneshot<F, Request<B>>), +pin_project! { + #[project = RouteFutureInnerProj] + #[derive(Debug)] + enum RouteFutureInner<S, F, B> + where + S: Service<Request<B>>, + F: Service<Request<B>>, + { + A { #[pin] a: Oneshot<S, Request<B>> }, + B { #[pin] b: Oneshot<F, Request<B>> }, + } } impl<S, F, B> Future for RouteFuture<S, F, B> @@ -426,9 +435,9 @@ where type Output = Result<Response<BoxBody>, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - match self.project().0.project() { - RouteFutureInnerProj::A(inner) => inner.poll(cx), - RouteFutureInnerProj::B(inner) => inner.poll(cx), + match self.project().inner.project() { + RouteFutureInnerProj::A { a } => a.poll(cx), + RouteFutureInnerProj::B { b } => b.poll(cx), } } } @@ -510,7 +519,9 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> { fn call(&mut self, _req: Request<B>) -> Self::Future { let mut res = Response::new(crate::body::empty()); *res.status_mut() = self.status; - EmptyRouterFuture(future::ok(res)) + EmptyRouterFuture { + future: future::ok(res), + } } } @@ -655,15 +666,14 @@ where } } -/// The response future for [`BoxRoute`]. -#[pin_project] -pub struct BoxRouteFuture<B, E> -where - E: Into<BoxError>, -{ - #[pin] - inner: - Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>, +pin_project! { + /// The response future for [`BoxRoute`]. + pub struct BoxRouteFuture<B, E> + where + E: Into<BoxError>, + { + #[pin] inner: Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>, + } } impl<B, E> Future for BoxRouteFuture<B, E> diff --git a/src/service/future.rs b/src/service/future.rs index 7e6cd421..0fbd06dd 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -7,7 +7,7 @@ use crate::{ use bytes::Bytes; use futures_util::ready; use http::Response; -use pin_project::pin_project; +use pin_project_lite::pin_project; use std::{ future::Future, pin::Pin, @@ -15,13 +15,14 @@ use std::{ }; use tower::BoxError; -/// Response future for [`HandleError`](super::HandleError). -#[pin_project] -#[derive(Debug)] -pub struct HandleErrorFuture<Fut, F> { - #[pin] - pub(super) inner: Fut, - pub(super) f: Option<F>, +pin_project! { + /// Response future for [`HandleError`](super::HandleError). + #[derive(Debug)] + pub struct HandleErrorFuture<Fut, F> { + #[pin] + pub(super) inner: Fut, + pub(super) f: Option<F>, + } } impl<Fut, F, E, E2, B, Res> Future for HandleErrorFuture<Fut, F> diff --git a/src/service/mod.rs b/src/service/mod.rs index 7e90468d..7136f2aa 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -94,7 +94,7 @@ use crate::{ use bytes::Bytes; use futures_util::ready; use http::{Request, Response}; -use pin_project::pin_project; +use pin_project_lite::pin_project; use std::{ convert::Infallible, fmt, @@ -645,14 +645,17 @@ where fn call(&mut self, req: Request<ReqBody>) -> Self::Future { let fut = self.inner.clone().oneshot(req); - BoxResponseBodyFuture(fut) + BoxResponseBodyFuture { future: fut } } } -/// Response future for [`BoxResponseBody`]. -#[pin_project] -#[derive(Debug)] -pub struct BoxResponseBodyFuture<F>(#[pin] F); +pin_project! { + /// Response future for [`BoxResponseBody`]. + #[derive(Debug)] + pub struct BoxResponseBodyFuture<F> { + #[pin] future: F, + } +} impl<F, B, E> Future for BoxResponseBodyFuture<F> where @@ -663,7 +666,7 @@ where type Output = Result<Response<BoxBody>, E>; fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - let res = ready!(self.project().0.poll(cx))?; + let res = ready!(self.project().future.poll(cx))?; let res = res.map(box_body); Poll::Ready(Ok(res)) } diff --git a/src/sse.rs b/src/sse.rs index b7ee46ce..d3794911 100644 --- a/src/sse.rs +++ b/src/sse.rs @@ -81,7 +81,7 @@ use futures_util::{ }; use http::{Request, Response}; use hyper::Body; -use pin_project::pin_project; +use pin_project_lite::pin_project; use serde::Serialize; use std::{ borrow::Cow, @@ -245,49 +245,51 @@ where let handler = self.handler.clone(); let keep_alive = self.keep_alive.clone(); - ResponseFuture(Box::pin(async move { - let mut req = RequestParts::new(req); - let input = match T::from_request(&mut req).await { - Ok(input) => input, - Err(err) => { - return Ok(err.into_response().map(box_body)); - } - }; + ResponseFuture { + future: Box::pin(async move { + let mut req = RequestParts::new(req); + let input = match T::from_request(&mut req).await { + Ok(input) => input, + Err(err) => { + return Ok(err.into_response().map(box_body)); + } + }; - let stream = match handler.call(input).await { - Ok(stream) => stream, - Err(err) => { - return Ok(err.into_response().map(box_body)); - } - }; + let stream = match handler.call(input).await { + Ok(stream) => stream, + Err(err) => { + return Ok(err.into_response().map(box_body)); + } + }; - let stream = if let Some(keep_alive) = keep_alive { - KeepAliveStream { - event_stream: stream, - comment_text: keep_alive.comment_text, - max_interval: keep_alive.max_interval, - alive_timer: tokio::time::sleep(keep_alive.max_interval), - } - .left_stream() - } else { - stream.into_stream().right_stream() - }; + let stream = if let Some(keep_alive) = keep_alive { + KeepAliveStream { + event_stream: stream, + comment_text: keep_alive.comment_text, + max_interval: keep_alive.max_interval, + alive_timer: tokio::time::sleep(keep_alive.max_interval), + } + .left_stream() + } else { + stream.into_stream().right_stream() + }; - let stream = stream - .map_ok(|event| event.to_string()) - .map_err(|err| BoxStdError(err.into())) - .into_stream(); + let stream = stream + .map_ok(|event| event.to_string()) + .map_err(|err| BoxStdError(err.into())) + .into_stream(); - let body = box_body(Body::wrap_stream(stream)); + let body = box_body(Body::wrap_stream(stream)); - let response = Response::builder() - .header(http::header::CONTENT_TYPE, "text/event-stream") - .header(http::header::CACHE_CONTROL, "no-cache") - .body(body) - .unwrap(); + let response = Response::builder() + .header(http::header::CONTENT_TYPE, "text/event-stream") + .header(http::header::CACHE_CONTROL, "no-cache") + .body(body) + .unwrap(); - Ok(response) - })) + Ok(response) + }), + } } } @@ -483,14 +485,15 @@ impl Default for KeepAlive { } } -#[pin_project] -struct KeepAliveStream<S> { - #[pin] - event_stream: S, - comment_text: Cow<'static, str>, - max_interval: Duration, - #[pin] - alive_timer: Sleep, +pin_project! { + struct KeepAliveStream<S> { + #[pin] + event_stream: S, + comment_text: Cow<'static, str>, + max_interval: Duration, + #[pin] + alive_timer: Sleep, + } } impl<S> Stream for KeepAliveStream<S> diff --git a/src/ws/mod.rs b/src/ws/mod.rs index dff318df..fa8bb795 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -277,110 +277,112 @@ where let this = self.clone(); let protocols = self.protocols.clone(); - ResponseFuture(Box::pin(async move { - if req.method() != http::Method::GET { - return response(StatusCode::NOT_FOUND, "Request method must be `GET`"); - } + ResponseFuture { + future: Box::pin(async move { + if req.method() != http::Method::GET { + return response(StatusCode::NOT_FOUND, "Request method must be `GET`"); + } - if !header_contains(&req, header::CONNECTION, "upgrade") { - return response( - StatusCode::BAD_REQUEST, - "Connection header did not include 'upgrade'", - ); - } + if !header_contains(&req, header::CONNECTION, "upgrade") { + return response( + StatusCode::BAD_REQUEST, + "Connection header did not include 'upgrade'", + ); + } - if !header_eq(&req, header::UPGRADE, "websocket") { - return response( - StatusCode::BAD_REQUEST, - "`Upgrade` header did not include 'websocket'", - ); - } + if !header_eq(&req, header::UPGRADE, "websocket") { + return response( + StatusCode::BAD_REQUEST, + "`Upgrade` header did not include 'websocket'", + ); + } - if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") { - return response( - StatusCode::BAD_REQUEST, - "`Sec-Websocket-Version` header did not include '13'", - ); - } + if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") { + return response( + StatusCode::BAD_REQUEST, + "`Sec-Websocket-Version` header did not include '13'", + ); + } - // check requested protocols - let protocol = - req.headers() - .get(&header::SEC_WEBSOCKET_PROTOCOL) - .and_then(|req_protocols| { - let req_protocols = req_protocols.to_str().ok()?; - req_protocols - .split(',') - .map(|req_p| req_p.trim()) - .find(|req_p| protocols.iter().any(|p| p == req_p)) - }); - let protocol = match protocol { - Some(protocol) => { - if let Ok(protocol) = HeaderValue::from_str(protocol) { - Some(protocol) - } else { - return response( - StatusCode::BAD_REQUEST, - "`Sec-Websocket-Protocol` header is invalid", - ); + // check requested protocols + let protocol = + req.headers() + .get(&header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|req_protocols| { + let req_protocols = req_protocols.to_str().ok()?; + req_protocols + .split(',') + .map(|req_p| req_p.trim()) + .find(|req_p| protocols.iter().any(|p| p == req_p)) + }); + let protocol = match protocol { + Some(protocol) => { + if let Ok(protocol) = HeaderValue::from_str(protocol) { + Some(protocol) + } else { + return response( + StatusCode::BAD_REQUEST, + "`Sec-Websocket-Protocol` header is invalid", + ); + } } + None => None, + }; + + let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) { + key + } else { + return response( + StatusCode::BAD_REQUEST, + "`Sec-Websocket-Key` header missing", + ); + }; + + let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap(); + + let config = this.config; + let callback = this.callback.clone(); + + let mut req = RequestParts::new(req); + let input = match T::from_request(&mut req).await { + Ok(input) => input, + Err(rejection) => { + let res = rejection.into_response().map(box_body); + return Ok(res); + } + }; + + tokio::spawn(async move { + let upgraded = on_upgrade.await.unwrap(); + let socket = WebSocketStream::from_raw_socket( + upgraded, + protocol::Role::Server, + Some(config), + ) + .await; + let socket = WebSocket { inner: socket }; + callback.call(socket, input).await; + }); + + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header( + http::header::CONNECTION, + HeaderValue::from_str("upgrade").unwrap(), + ) + .header( + http::header::UPGRADE, + HeaderValue::from_str("websocket").unwrap(), + ) + .header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes())); + if let Some(protocol) = protocol { + builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); } - None => None, - }; + let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap(); - let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) { - key - } else { - return response( - StatusCode::BAD_REQUEST, - "`Sec-Websocket-Key` header missing", - ); - }; - - let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap(); - - let config = this.config; - let callback = this.callback.clone(); - - let mut req = RequestParts::new(req); - let input = match T::from_request(&mut req).await { - Ok(input) => input, - Err(rejection) => { - let res = rejection.into_response().map(box_body); - return Ok(res); - } - }; - - tokio::spawn(async move { - let upgraded = on_upgrade.await.unwrap(); - let socket = WebSocketStream::from_raw_socket( - upgraded, - protocol::Role::Server, - Some(config), - ) - .await; - let socket = WebSocket { inner: socket }; - callback.call(socket, input).await; - }); - - let mut builder = Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header( - http::header::CONNECTION, - HeaderValue::from_str("upgrade").unwrap(), - ) - .header( - http::header::UPGRADE, - HeaderValue::from_str("websocket").unwrap(), - ) - .header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes())); - if let Some(protocol) = protocol { - builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol); - } - let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap(); - - Ok(res) - })) + Ok(res) + }), + } } }