mirror of
https://github.com/tokio-rs/axum.git
synced 2025-03-13 19:27:53 +01:00
Use pin-project-lite
instead of pin-project
(#95)
This commit is contained in:
parent
ba74787532
commit
be68227d73
12 changed files with 268 additions and 236 deletions
|
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91))
|
||||
- Add support for WebSocket protocol negotiation. ([#83](https://github.com/tokio-rs/axum/pull/83))
|
||||
- Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95))
|
||||
|
||||
## Breaking changes
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ futures-util = "0.3"
|
|||
http = "0.2"
|
||||
http-body = "0.4.2"
|
||||
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
|
||||
pin-project = "1.0"
|
||||
pin-project-lite = "0.2.7"
|
||||
regex = "1.5"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use futures_util::ready;
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
|
@ -119,23 +119,26 @@ where
|
|||
});
|
||||
|
||||
ResponseFuture {
|
||||
state: State::Channel(reply_rx),
|
||||
state: State::Channel { reply_rx },
|
||||
permit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub(crate) struct ResponseFuture<F, E> {
|
||||
#[pin]
|
||||
state: State<F, E>,
|
||||
permit: OwnedSemaphorePermit,
|
||||
pin_project! {
|
||||
pub(crate) struct ResponseFuture<F, E> {
|
||||
#[pin]
|
||||
state: State<F, E>,
|
||||
permit: OwnedSemaphorePermit,
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = StateProj)]
|
||||
enum State<F, E> {
|
||||
Channel(oneshot::Receiver<WorkerReply<F, E>>),
|
||||
Future(#[pin] F),
|
||||
pin_project! {
|
||||
#[project = StateProj]
|
||||
enum State<F, E> {
|
||||
Channel { reply_rx: oneshot::Receiver<WorkerReply<F, E>> },
|
||||
Future { #[pin] future: F },
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, E, T> Future for ResponseFuture<F, E>
|
||||
|
@ -149,16 +152,16 @@ where
|
|||
let mut this = self.as_mut().project();
|
||||
|
||||
let new_state = match this.state.as_mut().project() {
|
||||
StateProj::Channel(reply_rx) => {
|
||||
StateProj::Channel { reply_rx } => {
|
||||
let msg = ready!(Pin::new(reply_rx).poll(cx))
|
||||
.expect("buffer worker not running. This is a bug in axum and should never happen. Please file an issue");
|
||||
|
||||
match msg {
|
||||
WorkerReply::Future(future) => State::Future(future),
|
||||
WorkerReply::Future(future) => State::Future { future },
|
||||
WorkerReply::Error(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
StateProj::Future(future) => {
|
||||
StateProj::Future { future } => {
|
||||
return future.poll(cx);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -89,7 +89,9 @@ where
|
|||
fn call(&mut self, target: T) -> Self::Future {
|
||||
let connect_info = ConnectInfo(C::connect_info(target));
|
||||
let svc = AddExtension::new(self.svc.clone(), connect_info);
|
||||
ResponseFuture(futures_util::future::ok(svc))
|
||||
ResponseFuture {
|
||||
future: futures_util::future::ok(svc),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{body::BoxBody, response::IntoResponse};
|
|||
use bytes::Bytes;
|
||||
use futures_util::{future::BoxFuture, ready};
|
||||
use http::{Request, Response};
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
fmt,
|
||||
future::Future,
|
||||
|
@ -178,33 +178,38 @@ where
|
|||
});
|
||||
|
||||
ExtractorMiddlewareResponseFuture {
|
||||
state: State::Extracting(extract_future),
|
||||
state: State::Extracting {
|
||||
future: extract_future,
|
||||
},
|
||||
svc: Some(self.inner.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response future for [`ExtractorMiddleware`].
|
||||
#[allow(missing_debug_implementations)]
|
||||
#[pin_project]
|
||||
pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
|
||||
where
|
||||
E: FromRequest<ReqBody>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
{
|
||||
#[pin]
|
||||
state: State<ReqBody, S, E>,
|
||||
svc: Option<S>,
|
||||
pin_project! {
|
||||
/// Response future for [`ExtractorMiddleware`].
|
||||
#[allow(missing_debug_implementations)]
|
||||
pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
|
||||
where
|
||||
E: FromRequest<ReqBody>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
{
|
||||
#[pin]
|
||||
state: State<ReqBody, S, E>,
|
||||
svc: Option<S>,
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = StateProj)]
|
||||
enum State<ReqBody, S, E>
|
||||
where
|
||||
E: FromRequest<ReqBody>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
{
|
||||
Extracting(BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)>),
|
||||
Call(#[pin] S::Future),
|
||||
pin_project! {
|
||||
#[project = StateProj]
|
||||
enum State<ReqBody, S, E>
|
||||
where
|
||||
E: FromRequest<ReqBody>,
|
||||
S: Service<Request<ReqBody>>,
|
||||
{
|
||||
Extracting { future: BoxFuture<'static, (RequestParts<ReqBody>, Result<E, E::Rejection>)> },
|
||||
Call { #[pin] future: S::Future },
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody, S, E, ResBody> Future for ExtractorMiddlewareResponseFuture<ReqBody, S, E>
|
||||
|
@ -221,14 +226,14 @@ where
|
|||
let mut this = self.as_mut().project();
|
||||
|
||||
let new_state = match this.state.as_mut().project() {
|
||||
StateProj::Extracting(future) => {
|
||||
StateProj::Extracting { future } => {
|
||||
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
|
||||
|
||||
match extracted {
|
||||
Ok(_) => {
|
||||
let mut svc = this.svc.take().expect("future polled after completion");
|
||||
let future = svc.call(req.into_request());
|
||||
State::Call(future)
|
||||
State::Call { future }
|
||||
}
|
||||
Err(err) => {
|
||||
let res = err.into_response().map(crate::body::box_body);
|
||||
|
@ -236,7 +241,7 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
StateProj::Call(future) => {
|
||||
StateProj::Call { future } => {
|
||||
return future
|
||||
.poll(cx)
|
||||
.map(|result| result.map(|response| response.map(crate::body::box_body)));
|
||||
|
|
|
@ -444,7 +444,7 @@ where
|
|||
let res = Handler::call(handler, req).await;
|
||||
Ok(res)
|
||||
});
|
||||
future::IntoServiceFuture(future)
|
||||
future::IntoServiceFuture { future }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,10 +9,12 @@ macro_rules! opaque_future {
|
|||
};
|
||||
|
||||
($(#[$m:meta])* pub type $name:ident<$($param:ident),*> = $actual:ty;) => {
|
||||
#[pin_project::pin_project]
|
||||
$(#[$m])*
|
||||
pub struct $name<$($param),*>(#[pin] pub(crate) $actual)
|
||||
where;
|
||||
pin_project_lite::pin_project! {
|
||||
$(#[$m])*
|
||||
pub struct $name<$($param),*> {
|
||||
#[pin] pub(crate) future: $actual,
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($param),*> std::fmt::Debug for $name<$($param),*> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
|
@ -27,7 +29,7 @@ macro_rules! opaque_future {
|
|||
type Output = <$actual as std::future::Future>::Output;
|
||||
#[inline]
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
|
||||
self.project().0.poll(cx)
|
||||
self.project().future.poll(cx)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -11,7 +11,7 @@ use async_trait::async_trait;
|
|||
use bytes::Bytes;
|
||||
use futures_util::future;
|
||||
use http::{Method, Request, Response, StatusCode, Uri};
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use regex::Regex;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
|
@ -385,13 +385,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// The response future for [`Route`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct RouteFuture<S, F, B>(#[pin] RouteFutureInner<S, F, B>)
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>>;
|
||||
pin_project! {
|
||||
/// The response future for [`Route`].
|
||||
#[derive(Debug)]
|
||||
pub struct RouteFuture<S, F, B>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>> {
|
||||
#[pin] inner: RouteFutureInner<S, F, B>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, F, B> RouteFuture<S, F, B>
|
||||
where
|
||||
|
@ -399,23 +402,29 @@ where
|
|||
F: Service<Request<B>>,
|
||||
{
|
||||
pub(crate) fn a(a: Oneshot<S, Request<B>>) -> Self {
|
||||
RouteFuture(RouteFutureInner::A(a))
|
||||
RouteFuture {
|
||||
inner: RouteFutureInner::A { a },
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn b(b: Oneshot<F, Request<B>>) -> Self {
|
||||
RouteFuture(RouteFutureInner::B(b))
|
||||
RouteFuture {
|
||||
inner: RouteFutureInner::B { b },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project(project = RouteFutureInnerProj)]
|
||||
#[derive(Debug)]
|
||||
enum RouteFutureInner<S, F, B>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>>,
|
||||
{
|
||||
A(#[pin] Oneshot<S, Request<B>>),
|
||||
B(#[pin] Oneshot<F, Request<B>>),
|
||||
pin_project! {
|
||||
#[project = RouteFutureInnerProj]
|
||||
#[derive(Debug)]
|
||||
enum RouteFutureInner<S, F, B>
|
||||
where
|
||||
S: Service<Request<B>>,
|
||||
F: Service<Request<B>>,
|
||||
{
|
||||
A { #[pin] a: Oneshot<S, Request<B>> },
|
||||
B { #[pin] b: Oneshot<F, Request<B>> },
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, F, B> Future for RouteFuture<S, F, B>
|
||||
|
@ -426,9 +435,9 @@ where
|
|||
type Output = Result<Response<BoxBody>, S::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.project().0.project() {
|
||||
RouteFutureInnerProj::A(inner) => inner.poll(cx),
|
||||
RouteFutureInnerProj::B(inner) => inner.poll(cx),
|
||||
match self.project().inner.project() {
|
||||
RouteFutureInnerProj::A { a } => a.poll(cx),
|
||||
RouteFutureInnerProj::B { b } => b.poll(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -510,7 +519,9 @@ impl<B, E> Service<Request<B>> for EmptyRouter<E> {
|
|||
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
||||
let mut res = Response::new(crate::body::empty());
|
||||
*res.status_mut() = self.status;
|
||||
EmptyRouterFuture(future::ok(res))
|
||||
EmptyRouterFuture {
|
||||
future: future::ok(res),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -655,15 +666,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// The response future for [`BoxRoute`].
|
||||
#[pin_project]
|
||||
pub struct BoxRouteFuture<B, E>
|
||||
where
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
#[pin]
|
||||
inner:
|
||||
Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
|
||||
pin_project! {
|
||||
/// The response future for [`BoxRoute`].
|
||||
pub struct BoxRouteFuture<B, E>
|
||||
where
|
||||
E: Into<BoxError>,
|
||||
{
|
||||
#[pin] inner: Oneshot<MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>, Request<B>>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, E> Future for BoxRouteFuture<B, E>
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
use bytes::Bytes;
|
||||
use futures_util::ready;
|
||||
use http::Response;
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
|
@ -15,13 +15,14 @@ use std::{
|
|||
};
|
||||
use tower::BoxError;
|
||||
|
||||
/// Response future for [`HandleError`](super::HandleError).
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct HandleErrorFuture<Fut, F> {
|
||||
#[pin]
|
||||
pub(super) inner: Fut,
|
||||
pub(super) f: Option<F>,
|
||||
pin_project! {
|
||||
/// Response future for [`HandleError`](super::HandleError).
|
||||
#[derive(Debug)]
|
||||
pub struct HandleErrorFuture<Fut, F> {
|
||||
#[pin]
|
||||
pub(super) inner: Fut,
|
||||
pub(super) f: Option<F>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<Fut, F, E, E2, B, Res> Future for HandleErrorFuture<Fut, F>
|
||||
|
|
|
@ -94,7 +94,7 @@ use crate::{
|
|||
use bytes::Bytes;
|
||||
use futures_util::ready;
|
||||
use http::{Request, Response};
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
|
@ -645,14 +645,17 @@ where
|
|||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
let fut = self.inner.clone().oneshot(req);
|
||||
BoxResponseBodyFuture(fut)
|
||||
BoxResponseBodyFuture { future: fut }
|
||||
}
|
||||
}
|
||||
|
||||
/// Response future for [`BoxResponseBody`].
|
||||
#[pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct BoxResponseBodyFuture<F>(#[pin] F);
|
||||
pin_project! {
|
||||
/// Response future for [`BoxResponseBody`].
|
||||
#[derive(Debug)]
|
||||
pub struct BoxResponseBodyFuture<F> {
|
||||
#[pin] future: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, B, E> Future for BoxResponseBodyFuture<F>
|
||||
where
|
||||
|
@ -663,7 +666,7 @@ where
|
|||
type Output = Result<Response<BoxBody>, E>;
|
||||
|
||||
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let res = ready!(self.project().0.poll(cx))?;
|
||||
let res = ready!(self.project().future.poll(cx))?;
|
||||
let res = res.map(box_body);
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
|
|
95
src/sse.rs
95
src/sse.rs
|
@ -81,7 +81,7 @@ use futures_util::{
|
|||
};
|
||||
use http::{Request, Response};
|
||||
use hyper::Body;
|
||||
use pin_project::pin_project;
|
||||
use pin_project_lite::pin_project;
|
||||
use serde::Serialize;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
|
@ -245,49 +245,51 @@ where
|
|||
let handler = self.handler.clone();
|
||||
let keep_alive = self.keep_alive.clone();
|
||||
|
||||
ResponseFuture(Box::pin(async move {
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(err) => {
|
||||
return Ok(err.into_response().map(box_body));
|
||||
}
|
||||
};
|
||||
ResponseFuture {
|
||||
future: Box::pin(async move {
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(err) => {
|
||||
return Ok(err.into_response().map(box_body));
|
||||
}
|
||||
};
|
||||
|
||||
let stream = match handler.call(input).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Ok(err.into_response().map(box_body));
|
||||
}
|
||||
};
|
||||
let stream = match handler.call(input).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Ok(err.into_response().map(box_body));
|
||||
}
|
||||
};
|
||||
|
||||
let stream = if let Some(keep_alive) = keep_alive {
|
||||
KeepAliveStream {
|
||||
event_stream: stream,
|
||||
comment_text: keep_alive.comment_text,
|
||||
max_interval: keep_alive.max_interval,
|
||||
alive_timer: tokio::time::sleep(keep_alive.max_interval),
|
||||
}
|
||||
.left_stream()
|
||||
} else {
|
||||
stream.into_stream().right_stream()
|
||||
};
|
||||
let stream = if let Some(keep_alive) = keep_alive {
|
||||
KeepAliveStream {
|
||||
event_stream: stream,
|
||||
comment_text: keep_alive.comment_text,
|
||||
max_interval: keep_alive.max_interval,
|
||||
alive_timer: tokio::time::sleep(keep_alive.max_interval),
|
||||
}
|
||||
.left_stream()
|
||||
} else {
|
||||
stream.into_stream().right_stream()
|
||||
};
|
||||
|
||||
let stream = stream
|
||||
.map_ok(|event| event.to_string())
|
||||
.map_err(|err| BoxStdError(err.into()))
|
||||
.into_stream();
|
||||
let stream = stream
|
||||
.map_ok(|event| event.to_string())
|
||||
.map_err(|err| BoxStdError(err.into()))
|
||||
.into_stream();
|
||||
|
||||
let body = box_body(Body::wrap_stream(stream));
|
||||
let body = box_body(Body::wrap_stream(stream));
|
||||
|
||||
let response = Response::builder()
|
||||
.header(http::header::CONTENT_TYPE, "text/event-stream")
|
||||
.header(http::header::CACHE_CONTROL, "no-cache")
|
||||
.body(body)
|
||||
.unwrap();
|
||||
let response = Response::builder()
|
||||
.header(http::header::CONTENT_TYPE, "text/event-stream")
|
||||
.header(http::header::CACHE_CONTROL, "no-cache")
|
||||
.body(body)
|
||||
.unwrap();
|
||||
|
||||
Ok(response)
|
||||
}))
|
||||
Ok(response)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -483,14 +485,15 @@ impl Default for KeepAlive {
|
|||
}
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
struct KeepAliveStream<S> {
|
||||
#[pin]
|
||||
event_stream: S,
|
||||
comment_text: Cow<'static, str>,
|
||||
max_interval: Duration,
|
||||
#[pin]
|
||||
alive_timer: Sleep,
|
||||
pin_project! {
|
||||
struct KeepAliveStream<S> {
|
||||
#[pin]
|
||||
event_stream: S,
|
||||
comment_text: Cow<'static, str>,
|
||||
max_interval: Duration,
|
||||
#[pin]
|
||||
alive_timer: Sleep,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Stream for KeepAliveStream<S>
|
||||
|
|
196
src/ws/mod.rs
196
src/ws/mod.rs
|
@ -277,110 +277,112 @@ where
|
|||
let this = self.clone();
|
||||
let protocols = self.protocols.clone();
|
||||
|
||||
ResponseFuture(Box::pin(async move {
|
||||
if req.method() != http::Method::GET {
|
||||
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||
}
|
||||
ResponseFuture {
|
||||
future: Box::pin(async move {
|
||||
if req.method() != http::Method::GET {
|
||||
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||
}
|
||||
|
||||
if !header_contains(&req, header::CONNECTION, "upgrade") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
if !header_contains(&req, header::CONNECTION, "upgrade") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Connection header did not include 'upgrade'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::UPGRADE, "websocket") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
if !header_eq(&req, header::UPGRADE, "websocket") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Upgrade` header did not include 'websocket'",
|
||||
);
|
||||
}
|
||||
|
||||
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
if !header_eq(&req, header::SEC_WEBSOCKET_VERSION, "13") {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Version` header did not include '13'",
|
||||
);
|
||||
}
|
||||
|
||||
// check requested protocols
|
||||
let protocol =
|
||||
req.headers()
|
||||
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Protocol` header is invalid",
|
||||
);
|
||||
// check requested protocols
|
||||
let protocol =
|
||||
req.headers()
|
||||
.get(&header::SEC_WEBSOCKET_PROTOCOL)
|
||||
.and_then(|req_protocols| {
|
||||
let req_protocols = req_protocols.to_str().ok()?;
|
||||
req_protocols
|
||||
.split(',')
|
||||
.map(|req_p| req_p.trim())
|
||||
.find(|req_p| protocols.iter().any(|p| p == req_p))
|
||||
});
|
||||
let protocol = match protocol {
|
||||
Some(protocol) => {
|
||||
if let Ok(protocol) = HeaderValue::from_str(protocol) {
|
||||
Some(protocol)
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Protocol` header is invalid",
|
||||
);
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
|
||||
let config = this.config;
|
||||
let callback = this.callback.clone();
|
||||
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(rejection) => {
|
||||
let res = rejection.into_response().map(box_body);
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
upgraded,
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback.call(socket, input).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
|
||||
|
||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||
key
|
||||
} else {
|
||||
return response(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"`Sec-Websocket-Key` header missing",
|
||||
);
|
||||
};
|
||||
|
||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||
|
||||
let config = this.config;
|
||||
let callback = this.callback.clone();
|
||||
|
||||
let mut req = RequestParts::new(req);
|
||||
let input = match T::from_request(&mut req).await {
|
||||
Ok(input) => input,
|
||||
Err(rejection) => {
|
||||
let res = rejection.into_response().map(box_body);
|
||||
return Ok(res);
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
let upgraded = on_upgrade.await.unwrap();
|
||||
let socket = WebSocketStream::from_raw_socket(
|
||||
upgraded,
|
||||
protocol::Role::Server,
|
||||
Some(config),
|
||||
)
|
||||
.await;
|
||||
let socket = WebSocket { inner: socket };
|
||||
callback.call(socket, input).await;
|
||||
});
|
||||
|
||||
let mut builder = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(
|
||||
http::header::CONNECTION,
|
||||
HeaderValue::from_str("upgrade").unwrap(),
|
||||
)
|
||||
.header(
|
||||
http::header::UPGRADE,
|
||||
HeaderValue::from_str("websocket").unwrap(),
|
||||
)
|
||||
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()));
|
||||
if let Some(protocol) = protocol {
|
||||
builder = builder.header(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||
}
|
||||
let res = builder.body(box_body(Full::new(Bytes::new()))).unwrap();
|
||||
|
||||
Ok(res)
|
||||
}))
|
||||
Ok(res)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue