From be68227d739aece4210d4e7cb7d419da4d4afe30 Mon Sep 17 00:00:00 2001
From: Sunli <sunlipad4@icloud.com>
Date: Tue, 3 Aug 2021 15:33:00 +0800
Subject: [PATCH] Use `pin-project-lite` instead of `pin-project` (#95)

---
 CHANGELOG.md                        |   1 +
 Cargo.toml                          |   2 +-
 src/buffer.rs                       |  31 +++--
 src/extract/connect_info.rs         |   4 +-
 src/extract/extractor_middleware.rs |  53 ++++----
 src/handler/mod.rs                  |   2 +-
 src/macros.rs                       |  12 +-
 src/routing.rs                      |  74 ++++++-----
 src/service/future.rs               |  17 +--
 src/service/mod.rs                  |  17 ++-
 src/sse.rs                          |  95 +++++++-------
 src/ws/mod.rs                       | 196 ++++++++++++++--------------
 12 files changed, 268 insertions(+), 236 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f07c3e76..484f7e08 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
 - Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
+- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95))
 
 ## Breaking changes
 
diff --git a/Cargo.toml b/Cargo.toml
index bcf17d4f..4234e409 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,7 +24,7 @@ futures-util = "0.3"
 http = "0.2"
 http-body = "0.4.2"
 hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
-pin-project = "1.0"
+pin-project-lite = "0.2.7"
 regex = "1.5"
 serde = "1.0"
 serde_json = "1.0"
diff --git a/src/buffer.rs b/src/buffer.rs
index c7134adc..4b85927c 100644
--- a/src/buffer.rs
+++ b/src/buffer.rs
@@ -1,5 +1,5 @@
 use futures_util::ready;
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use std::{
     future::Future,
     pin::Pin,
@@ -119,23 +119,26 @@ where
         });
 
         ResponseFuture {
-            state: State::Channel(reply_rx),
+            state: State::Channel { reply_rx },
             permit,
         }
     }
 }
 
-#[pin_project]
-pub(crate) struct ResponseFuture<F, E> {
-    #[pin]
-    state: State<F, E>,
-    permit: OwnedSemaphorePermit,
+pin_project! {
+    pub(crate) struct ResponseFuture<F, E> {
+        #[pin]
+        state: State<F, E>,
+        permit: OwnedSemaphorePermit,
+    }
 }
 
-#[pin_project(project = StateProj)]
-enum State<F, E> {
-    Channel(oneshot::Receiver<WorkerReply<F, E>>),
-    Future(#[pin] F),
+pin_project! {
+    #[project = StateProj]
+    enum State<F, E> {
+        Channel { reply_rx: oneshot::Receiver<WorkerReply<F, E>> },
+        Future { #[pin] future: F },
+    }
 }
 
 impl<F, E, T> Future for ResponseFuture<F, E>
@@ -149,16 +152,16 @@ where
             let mut this = self.as_mut().project();
 
             let new_state = match this.state.as_mut().project() {
-                StateProj::Channel(reply_rx) => {
+                StateProj::Channel { reply_rx } => {
                     let msg = ready!(Pin::new(reply_rx).poll(cx))
                         .expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue");
 
                     match msg {
-                        WorkerReply::Future(future) => State::Future(future),
+                        WorkerReply::Future(future) => State::Future { future },
                         WorkerReply::Error(err) => return Poll::Ready(Err(err)),
                     }
                 }
-                StateProj::Future(future) => {
+                StateProj::Future { future } => {
                     return future.poll(cx);
                 }
             };
diff --git a/src/extract/connect_info.rs b/src/extract/connect_info.rs
index cb92e095..0a0655c5 100644
--- a/src/extract/connect_info.rs
+++ b/src/extract/connect_info.rs
@@ -89,7 +89,9 @@ where
     fn call(&mut self, target: T) -> Self::Future {
         let connect_info = ConnectInfo(C::connect_info(target));
         let svc = AddExtension::new(self.svc.clone(), connect_info);
-        ResponseFuture(futures_util::future::ok(svc))
+        ResponseFuture {
+            future: futures_util::future::ok(svc),
+        }
     }
 }
 
diff --git a/src/extract/extractor_middleware.rs b/src/extract/extractor_middleware.rs
index a8fa6301..84372273 100644
--- a/src/extract/extractor_middleware.rs
+++ b/src/extract/extractor_middleware.rs
@@ -7,7 +7,7 @@ use crate::{body::BoxBody, response::IntoResponse};
 use bytes::Bytes;
 use futures_util::{future::BoxFuture, ready};
 use http::{Request, Response};
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use std::{
     fmt,
     future::Future,
@@ -178,33 +178,38 @@ where
         });
 
         ExtractorMiddlewareResponseFuture {
-            state: State::Extracting(extract_future),
+            state: State::Extracting {
+                future: extract_future,
+            },
             svc: Some(self.inner.clone()),
         }
     }
 }
 
-/// Response future for [`ExtractorMiddleware`].
-#[allow(missing_debug_implementations)]
-#[pin_project]
-pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
-where
-    E: FromRequest<ReqBody>,
-    S: Service<Request<ReqBody>>,
-{
-    #[pin]
-    state: State<ReqBody, S, E>,
-    svc: Option<S>,
+pin_project! {
+    /// Response future for [`ExtractorMiddleware`].
+    #[allow(missing_debug_implementations)]
+    pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
+    where
+        E: FromRequest<ReqBody>,
+        S: Service<Request<ReqBody>>,
+    {
+        #[pin]
+        state: State<ReqBody, S, E>,
+        svc: Option<S>,
+    }
 }
 
-#[pin_project(project = StateProj)]
-enum State<ReqBody, S, E>
-where
-    E: FromRequest<ReqBody>,
-    S: Service<Request<ReqBody>>,
-{
-    Extracting(BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)>),
-    Call(#[pin] S::Future),
+pin_project! {
+    #[project = StateProj]
+    enum State<ReqBody, S, E>
+    where
+        E: FromRequest<ReqBody>,
+        S: Service<Request<ReqBody>>,
+    {
+        Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> },
+        Call { #[pin] future: S::Future },
+    }
 }
 
 impl<ReqBody, S, E, ResBody> Future for ExtractorMiddlewareResponseFuture<ReqBody, S, E>
@@ -221,14 +226,14 @@ where
             let mut this = self.as_mut().project();
 
             let new_state = match this.state.as_mut().project() {
-                StateProj::Extracting(future) => {
+                StateProj::Extracting { future } => {
                     let (mut req, extracted) = ready!(future.as_mut().poll(cx));
 
                     match extracted {
                         Ok(_) => {
                             let mut svc = this.svc.take().expect("future polled after completion");
                             let future = svc.call(req.into_request());
-                            State::Call(future)
+                            State::Call { future }
                         }
                         Err(err) => {
                             let res = err.into_response().map(crate::body::box_body);
@@ -236,7 +241,7 @@ where
                         }
                     }
                 }
-                StateProj::Call(future) => {
+                StateProj::Call { future } => {
                     return future
                         .poll(cx)
                         .map(|result| result.map(|response| response.map(crate::body::box_body)));
diff --git a/src/handler/mod.rs b/src/handler/mod.rs
index aaa9792a..ed065b4f 100644
--- a/src/handler/mod.rs
+++ b/src/handler/mod.rs
@@ -444,7 +444,7 @@ where
             let res = Handler::call(handler, req).await;
             Ok(res)
         });
-        future::IntoServiceFuture(future)
+        future::IntoServiceFuture { future }
     }
 }
 
diff --git a/src/macros.rs b/src/macros.rs
index 15b8597f..238b21d0 100644
--- a/src/macros.rs
+++ b/src/macros.rs
@@ -9,10 +9,12 @@ macro_rules! opaque_future {
     };
 
     ($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => {
-        #[pin_project::pin_project]
-        $(#[$m])*
-        pub struct $name<$($param),*>(#[pin] pub(crate) $actual)
-        where;
+        pin_project_lite::pin_project! {
+            $(#[$m])*
+            pub struct $name<$($param),*> {
+                #[pin] pub(crate) future: $actual,
+            }
+        }
 
         impl<$($param),*> std::fmt::Debug for $name<$($param),*> {
             fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -27,7 +29,7 @@ macro_rules! opaque_future {
             type Output = <$actual as std::future::Future>::Output;
             #[inline]
             fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
-                self.project().0.poll(cx)
+                self.project().future.poll(cx)
             }
         }
     };
diff --git a/src/routing.rs b/src/routing.rs
index 21f79d41..4a74763b 100644
--- a/src/routing.rs
+++ b/src/routing.rs
@@ -11,7 +11,7 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use futures_util::future;
 use http::{Method, Request, Response, StatusCode, Uri};
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use regex::Regex;
 use std::{
     borrow::Cow,
@@ -385,13 +385,16 @@ where
     }
 }
 
-/// The response future for [`Route`].
-#[pin_project]
-#[derive(Debug)]
-pub struct RouteFuture<S, F, B>(#[pin] RouteFutureInner<S, F, B>)
-where
-    S: Service<Request<B>>,
-    F: Service<Request<B>>;
+pin_project! {
+    /// The response future for [`Route`].
+    #[derive(Debug)]
+    pub struct RouteFuture<S, F, B>
+    where
+        S: Service<Request<B>>,
+        F: Service<Request<B>> {
+        #[pin] inner: RouteFutureInner<S, F, B>,
+    }
+}
 
 impl<S, F, B> RouteFuture<S, F, B>
 where
@@ -399,23 +402,29 @@ where
     F: Service<Request<B>>,
 {
     pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
-        RouteFuture(RouteFutureInner::A(a))
+        RouteFuture {
+            inner: RouteFutureInner::A { a },
+        }
     }
 
     pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self {
-        RouteFuture(RouteFutureInner::B(b))
+        RouteFuture {
+            inner: RouteFutureInner::B { b },
+        }
     }
 }
 
-#[pin_project(project = RouteFutureInnerProj)]
-#[derive(Debug)]
-enum RouteFutureInner<S, F, B>
-where
-    S: Service<Request<B>>,
-    F: Service<Request<B>>,
-{
-    A(#[pin] Oneshot<S, Request<B>>),
-    B(#[pin] Oneshot<F, Request<B>>),
+pin_project! {
+    #[project = RouteFutureInnerProj]
+    #[derive(Debug)]
+    enum RouteFutureInner<S, F, B>
+    where
+        S: Service<Request<B>>,
+        F: Service<Request<B>>,
+    {
+        A { #[pin] a: Oneshot<S, Request<B>> },
+        B { #[pin] b: Oneshot<F, Request<B>> },
+    }
 }
 
 impl<S, F, B> Future for RouteFuture<S, F, B>
@@ -426,9 +435,9 @@ where
     type Output = Result<Response<BoxBody>, S::Error>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        match self.project().0.project() {
-            RouteFutureInnerProj::A(inner) => inner.poll(cx),
-            RouteFutureInnerProj::B(inner) => inner.poll(cx),
+        match self.project().inner.project() {
+            RouteFutureInnerProj::A { a } => a.poll(cx),
+            RouteFutureInnerProj::B { b } => b.poll(cx),
         }
     }
 }
@@ -510,7 +519,9 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> {
     fn call(&mut self, _req: Request<B>) -> Self::Future {
         let mut res = Response::new(crate::body::empty());
         *res.status_mut() = self.status;
-        EmptyRouterFuture(future::ok(res))
+        EmptyRouterFuture {
+            future: future::ok(res),
+        }
     }
 }
 
@@ -655,15 +666,14 @@ where
     }
 }
 
-/// The response future for [`BoxRoute`].
-#[pin_project]
-pub struct BoxRouteFuture<B, E>
-where
-    E: Into<BoxError>,
-{
-    #[pin]
-    inner:
-        Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
+pin_project! {
+    /// The response future for [`BoxRoute`].
+    pub struct BoxRouteFuture<B, E>
+    where
+        E: Into<BoxError>,
+    {
+        #[pin] inner: Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
+    }
 }
 
 impl<B, E> Future for BoxRouteFuture<B, E>
diff --git a/src/service/future.rs b/src/service/future.rs
index 7e6cd421..0fbd06dd 100644
--- a/src/service/future.rs
+++ b/src/service/future.rs
@@ -7,7 +7,7 @@ use crate::{
 use bytes::Bytes;
 use futures_util::ready;
 use http::Response;
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use std::{
     future::Future,
     pin::Pin,
@@ -15,13 +15,14 @@ use std::{
 };
 use tower::BoxError;
 
-/// Response future for [`HandleError`](super::HandleError).
-#[pin_project]
-#[derive(Debug)]
-pub struct HandleErrorFuture<Fut, F> {
-    #[pin]
-    pub(super) inner: Fut,
-    pub(super) f: Option<F>,
+pin_project! {
+    /// Response future for [`HandleError`](super::HandleError).
+    #[derive(Debug)]
+    pub struct HandleErrorFuture<Fut, F> {
+        #[pin]
+        pub(super) inner: Fut,
+        pub(super) f: Option<F>,
+    }
 }
 
 impl<Fut, F, E, E2, B, Res> Future for HandleErrorFuture<Fut, F>
diff --git a/src/service/mod.rs b/src/service/mod.rs
index 7e90468d..7136f2aa 100644
--- a/src/service/mod.rs
+++ b/src/service/mod.rs
@@ -94,7 +94,7 @@ use crate::{
 use bytes::Bytes;
 use futures_util::ready;
 use http::{Request, Response};
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use std::{
     convert::Infallible,
     fmt,
@@ -645,14 +645,17 @@ where
 
     fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
         let fut = self.inner.clone().oneshot(req);
-        BoxResponseBodyFuture(fut)
+        BoxResponseBodyFuture { future: fut }
     }
 }
 
-/// Response future for [`BoxResponseBody`].
-#[pin_project]
-#[derive(Debug)]
-pub struct BoxResponseBodyFuture<F>(#[pin] F);
+pin_project! {
+    /// Response future for [`BoxResponseBody`].
+    #[derive(Debug)]
+    pub struct BoxResponseBodyFuture<F> {
+        #[pin] future: F,
+    }
+}
 
 impl<F, B, E> Future for BoxResponseBodyFuture<F>
 where
@@ -663,7 +666,7 @@ where
     type Output = Result<Response<BoxBody>, E>;
 
     fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let res = ready!(self.project().0.poll(cx))?;
+        let res = ready!(self.project().future.poll(cx))?;
         let res = res.map(box_body);
         Poll::Ready(Ok(res))
     }
diff --git a/src/sse.rs b/src/sse.rs
index b7ee46ce..d3794911 100644
--- a/src/sse.rs
+++ b/src/sse.rs
@@ -81,7 +81,7 @@ use futures_util::{
 };
 use http::{Request, Response};
 use hyper::Body;
-use pin_project::pin_project;
+use pin_project_lite::pin_project;
 use serde::Serialize;
 use std::{
     borrow::Cow,
@@ -245,49 +245,51 @@ where
         let handler = self.handler.clone();
         let keep_alive = self.keep_alive.clone();
 
-        ResponseFuture(Box::pin(async move {
-            let mut req = RequestParts::new(req);
-            let input = match T::from_request(&mut req).await {
-                Ok(input) => input,
-                Err(err) => {
-                    return Ok(err.into_response().map(box_body));
-                }
-            };
+        ResponseFuture {
+            future: Box::pin(async move {
+                let mut req = RequestParts::new(req);
+                let input = match T::from_request(&mut req).await {
+                    Ok(input) => input,
+                    Err(err) => {
+                        return Ok(err.into_response().map(box_body));
+                    }
+                };
 
-            let stream = match handler.call(input).await {
-                Ok(stream) => stream,
-                Err(err) => {
-                    return Ok(err.into_response().map(box_body));
-                }
-            };
+                let stream = match handler.call(input).await {
+                    Ok(stream) => stream,
+                    Err(err) => {
+                        return Ok(err.into_response().map(box_body));
+                    }
+                };
 
-            let stream = if let Some(keep_alive) = keep_alive {
-                KeepAliveStream {
-                    event_stream: stream,
-                    comment_text: keep_alive.comment_text,
-                    max_interval: keep_alive.max_interval,
-                    alive_timer: tokio::time::sleep(keep_alive.max_interval),
-                }
-                .left_stream()
-            } else {
-                stream.into_stream().right_stream()
-            };
+                let stream = if let Some(keep_alive) = keep_alive {
+                    KeepAliveStream {
+                        event_stream: stream,
+                        comment_text: keep_alive.comment_text,
+                        max_interval: keep_alive.max_interval,
+                        alive_timer: tokio::time::sleep(keep_alive.max_interval),
+                    }
+                    .left_stream()
+                } else {
+                    stream.into_stream().right_stream()
+                };
 
-            let stream = stream
-                .map_ok(|event| event.to_string())
-                .map_err(|err| BoxStdError(err.into()))
-                .into_stream();
+                let stream = stream
+                    .map_ok(|event| event.to_string())
+                    .map_err(|err| BoxStdError(err.into()))
+                    .into_stream();
 
-            let body = box_body(Body::wrap_stream(stream));
+                let body = box_body(Body::wrap_stream(stream));
 
-            let response = Response::builder()
-                .header(http::header::CONTENT_TYPE, "text/event-stream")
-                .header(http::header::CACHE_CONTROL, "no-cache")
-                .body(body)
-                .unwrap();
+                let response = Response::builder()
+                    .header(http::header::CONTENT_TYPE, "text/event-stream")
+                    .header(http::header::CACHE_CONTROL, "no-cache")
+                    .body(body)
+                    .unwrap();
 
-            Ok(response)
-        }))
+                Ok(response)
+            }),
+        }
     }
 }
 
@@ -483,14 +485,15 @@ impl Default for KeepAlive {
     }
 }
 
-#[pin_project]
-struct KeepAliveStream<S> {
-    #[pin]
-    event_stream: S,
-    comment_text: Cow<'static, str>,
-    max_interval: Duration,
-    #[pin]
-    alive_timer: Sleep,
+pin_project! {
+    struct KeepAliveStream<S> {
+        #[pin]
+        event_stream: S,
+        comment_text: Cow<'static, str>,
+        max_interval: Duration,
+        #[pin]
+        alive_timer: Sleep,
+    }
 }
 
 impl<S> Stream for KeepAliveStream<S>
diff --git a/src/ws/mod.rs b/src/ws/mod.rs
index dff318df..fa8bb795 100644
--- a/src/ws/mod.rs
+++ b/src/ws/mod.rs
@@ -277,110 +277,112 @@ where
         let this = self.clone();
         let protocols = self.protocols.clone();
 
-        ResponseFuture(Box::pin(async move {
-            if req.method() != http::Method::GET {
-                return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
-            }
+        ResponseFuture {
+            future: Box::pin(async move {
+                if req.method() != http::Method::GET {
+                    return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
+                }
 
-            if !header_contains(&req, header::CONNECTION, "upgrade") {
-                return response(
-                    StatusCode::BAD_REQUEST,
-                    "Connection header did not include 'upgrade'",
-                );
-            }
+                if !header_contains(&req, header::CONNECTION, "upgrade") {
+                    return response(
+                        StatusCode::BAD_REQUEST,
+                        "Connection header did not include 'upgrade'",
+                    );
+                }
 
-            if !header_eq(&req, header::UPGRADE, "websocket") {
-                return response(
-                    StatusCode::BAD_REQUEST,
-                    "`Upgrade` header did not include 'websocket'",
-                );
-            }
+                if !header_eq(&req, header::UPGRADE, "websocket") {
+                    return response(
+                        StatusCode::BAD_REQUEST,
+                        "`Upgrade` header did not include 'websocket'",
+                    );
+                }
 
-            if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
-                return response(
-                    StatusCode::BAD_REQUEST,
-                    "`Sec-Websocket-Version` header did not include '13'",
-                );
-            }
+                if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
+                    return response(
+                        StatusCode::BAD_REQUEST,
+                        "`Sec-Websocket-Version` header did not include '13'",
+                    );
+                }
 
-            // check requested protocols
-            let protocol =
-                req.headers()
-                    .get(&header::SEC_WEBSOCKET_PROTOCOL)
-                    .and_then(|req_protocols| {
-                        let req_protocols = req_protocols.to_str().ok()?;
-                        req_protocols
-                            .split(',')
-                            .map(|req_p| req_p.trim())
-                            .find(|req_p| protocols.iter().any(|p| p == req_p))
-                    });
-            let protocol = match protocol {
-                Some(protocol) => {
-                    if let Ok(protocol) = HeaderValue::from_str(protocol) {
-                        Some(protocol)
-                    } else {
-                        return response(
-                            StatusCode::BAD_REQUEST,
-                            "`Sec-Websocket-Protocol` header is invalid",
-                        );
+                // check requested protocols
+                let protocol =
+                    req.headers()
+                        .get(&header::SEC_WEBSOCKET_PROTOCOL)
+                        .and_then(|req_protocols| {
+                            let req_protocols = req_protocols.to_str().ok()?;
+                            req_protocols
+                                .split(',')
+                                .map(|req_p| req_p.trim())
+                                .find(|req_p| protocols.iter().any(|p| p == req_p))
+                        });
+                let protocol = match protocol {
+                    Some(protocol) => {
+                        if let Ok(protocol) = HeaderValue::from_str(protocol) {
+                            Some(protocol)
+                        } else {
+                            return response(
+                                StatusCode::BAD_REQUEST,
+                                "`Sec-Websocket-Protocol` header is invalid",
+                            );
+                        }
                     }
+                    None => None,
+                };
+
+                let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
+                    key
+                } else {
+                    return response(
+                        StatusCode::BAD_REQUEST,
+                        "`Sec-Websocket-Key` header missing",
+                    );
+                };
+
+                let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
+
+                let config = this.config;
+                let callback = this.callback.clone();
+
+                let mut req = RequestParts::new(req);
+                let input = match T::from_request(&mut req).await {
+                    Ok(input) => input,
+                    Err(rejection) => {
+                        let res = rejection.into_response().map(box_body);
+                        return Ok(res);
+                    }
+                };
+
+                tokio::spawn(async move {
+                    let upgraded = on_upgrade.await.unwrap();
+                    let socket = WebSocketStream::from_raw_socket(
+                        upgraded,
+                        protocol::Role::Server,
+                        Some(config),
+                    )
+                    .await;
+                    let socket = WebSocket { inner: socket };
+                    callback.call(socket, input).await;
+                });
+
+                let mut builder = Response::builder()
+                    .status(StatusCode::SWITCHING_PROTOCOLS)
+                    .header(
+                        http::header::CONNECTION,
+                        HeaderValue::from_str("upgrade").unwrap(),
+                    )
+                    .header(
+                        http::header::UPGRADE,
+                        HeaderValue::from_str("websocket").unwrap(),
+                    )
+                    .header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
+                if let Some(protocol) = protocol {
+                    builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
                 }
-                None => None,
-            };
+                let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
 
-            let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
-                key
-            } else {
-                return response(
-                    StatusCode::BAD_REQUEST,
-                    "`Sec-Websocket-Key` header missing",
-                );
-            };
-
-            let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
-
-            let config = this.config;
-            let callback = this.callback.clone();
-
-            let mut req = RequestParts::new(req);
-            let input = match T::from_request(&mut req).await {
-                Ok(input) => input,
-                Err(rejection) => {
-                    let res = rejection.into_response().map(box_body);
-                    return Ok(res);
-                }
-            };
-
-            tokio::spawn(async move {
-                let upgraded = on_upgrade.await.unwrap();
-                let socket = WebSocketStream::from_raw_socket(
-                    upgraded,
-                    protocol::Role::Server,
-                    Some(config),
-                )
-                .await;
-                let socket = WebSocket { inner: socket };
-                callback.call(socket, input).await;
-            });
-
-            let mut builder = Response::builder()
-                .status(StatusCode::SWITCHING_PROTOCOLS)
-                .header(
-                    http::header::CONNECTION,
-                    HeaderValue::from_str("upgrade").unwrap(),
-                )
-                .header(
-                    http::header::UPGRADE,
-                    HeaderValue::from_str("websocket").unwrap(),
-                )
-                .header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
-            if let Some(protocol) = protocol {
-                builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
-            }
-            let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
-
-            Ok(res)
-        }))
+                Ok(res)
+            }),
+        }
     }
 }