diff --git a/axum-core/src/layer.rs b/axum-core/src/layer.rs new file mode 100644 index 00000000..f8681dfb --- /dev/null +++ b/axum-core/src/layer.rs @@ -0,0 +1,5 @@ +trait Layer { + type Service; + + fn layer(&self, inner: S) -> Self::Service; +} diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index 974e5e18..ee3426a5 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -48,13 +48,19 @@ #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] #![forbid(unsafe_code)] #![cfg_attr(test, allow(clippy::float_cmp))] +#![feature(type_alias_impl_trait)] #[macro_use] pub(crate) mod macros; mod error; mod ext_traits; +//mod layer; +mod service; + pub use self::error::Error; +pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; +pub use self::service::Service; pub mod body; pub mod extract; @@ -62,5 +68,3 @@ pub mod response; /// Alias for a type-erased error type. pub type BoxError = Box; - -pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; diff --git a/axum-core/src/service.rs b/axum-core/src/service.rs new file mode 100644 index 00000000..f6fbfd79 --- /dev/null +++ b/axum-core/src/service.rs @@ -0,0 +1,39 @@ +#![allow(missing_docs)] // temporary + +use http::Request; +use std::{ + convert::Infallible, + future::Future, + task::{Context, Poll}, +}; +use tower_service::Service as TowerService; + +use crate::response::{IntoResponse, Response}; + +pub trait Service { + type Future: Future; + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; + fn call(&mut self, req: Request, state: &S) -> Self::Future; +} + +impl Service for T +where + T: TowerService, Response = Resp, Error = Infallible>, + Resp: IntoResponse, +{ + type Future = impl Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + TowerService::poll_ready(self, cx).map(|result| result.unwrap_or_else(|e| match e {})) + } + + fn call(&mut self, req: Request, _state: &S) -> Self::Future { + let fut = TowerService::call(self, req); + async move { + match fut.await { + Ok(res) => res.into_response(), + Err(e) => match e {}, + } + } + } +} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index d80f2507..08907bb8 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -145,7 +145,7 @@ pub trait Handler: Clone + Send + Sized + 'static { fn layer(self, layer: L) -> Layered where L: Layer> + Clone, - L::Service: Service>, + L::Service: crate::Service, { Layered { layer, diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 314f844f..c625181e 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -487,7 +487,7 @@ pub use self::typed_header::TypedHeader; pub use self::form::Form; #[doc(inline)] -pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; +pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt, Service}; #[cfg(feature = "macros")] pub use axum_macros::debug_handler; diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index a6822af8..542f20a2 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -1,6 +1,7 @@ use crate::{ extract::FromRequestParts, response::{IntoResponse, Response}, + Service, }; use futures_util::{future::BoxFuture, ready}; use http::Request; @@ -13,7 +14,6 @@ use std::{ task::{Context, Poll}, }; use tower_layer::Layer; -use tower_service::Service; /// Create a middleware from an extractor. /// @@ -90,16 +90,8 @@ use tower_service::Service; /// ``` /// /// [`Bytes`]: bytes::Bytes -pub fn from_extractor() -> FromExtractorLayer { - from_extractor_with_state(()) -} - -/// Create a middleware from an extractor with the given state. -/// -/// See [`State`](crate::extract::State) for more details about accessing state. -pub fn from_extractor_with_state(state: S) -> FromExtractorLayer { +pub fn from_extractor() -> FromExtractorLayer { FromExtractorLayer { - state, _marker: PhantomData, } } @@ -110,45 +102,32 @@ pub fn from_extractor_with_state(state: S) -> FromExtractorLayer { /// See [`from_extractor`] for more details. /// /// [`Layer`]: tower::Layer -pub struct FromExtractorLayer { - state: S, +pub struct FromExtractorLayer { _marker: PhantomData E>, } -impl Clone for FromExtractorLayer -where - S: Clone, -{ +impl Clone for FromExtractorLayer { fn clone(&self) -> Self { Self { - state: self.state.clone(), _marker: PhantomData, } } } -impl fmt::Debug for FromExtractorLayer -where - S: fmt::Debug, -{ +impl fmt::Debug for FromExtractorLayer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractorLayer") - .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } -impl Layer for FromExtractorLayer -where - S: Clone, -{ - type Service = FromExtractor; +impl Layer for FromExtractorLayer { + type Service = FromExtractor; fn layer(&self, inner: T) -> Self::Service { FromExtractor { inner, - state: self.state.clone(), _extractor: PhantomData, } } @@ -157,66 +136,58 @@ where /// Middleware that runs an extractor and discards the value. /// /// See [`from_extractor`] for more details. -pub struct FromExtractor { +pub struct FromExtractor { inner: T, - state: S, _extractor: PhantomData E>, } #[test] fn traits() { use crate::test_helpers::*; - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); } -impl Clone for FromExtractor +impl Clone for FromExtractor where T: Clone, - S: Clone, { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - state: self.state.clone(), _extractor: PhantomData, } } } -impl fmt::Debug for FromExtractor +impl fmt::Debug for FromExtractor where T: fmt::Debug, - S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromExtractor") .field("inner", &self.inner) - .field("state", &self.state) .field("extractor", &format_args!("{}", std::any::type_name::())) .finish() } } -impl Service> for FromExtractor +impl Service for FromExtractor where E: FromRequestParts + 'static, B: Send + 'static, - T: Service> + Clone, - T::Response: IntoResponse, + T: Service + Clone, S: Clone + Send + Sync + 'static, { - type Response = Response; - type Error = T::Error; type Future = ResponseFuture; #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { - let state = self.state.clone(); + fn call(&mut self, req: Request, st: &S) -> Self::Future { + let state = st.to_owned(); let extract_future = Box::pin(async move { let (mut parts, body) = req.into_parts(); let extracted = E::from_request_parts(&mut parts, &state).await; @@ -226,6 +197,7 @@ where ResponseFuture { state: State::Extracting { + st: st.clone(), future: extract_future, }, svc: Some(self.inner.clone()), @@ -239,7 +211,7 @@ pin_project! { pub struct ResponseFuture where E: FromRequestParts, - T: Service>, + T: Service, { #[pin] state: State, @@ -252,9 +224,10 @@ pin_project! { enum State where E: FromRequestParts, - T: Service>, + T: Service, { Extracting { + st: S, future: BoxFuture<'static, (Request, Result)>, }, Call { #[pin] future: T::Future }, @@ -264,35 +237,32 @@ pin_project! { impl Future for ResponseFuture where E: FromRequestParts, - T: Service>, - T::Response: IntoResponse, + T: Service, { - type Output = Result; + type Output = Response; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { let mut this = self.as_mut().project(); let new_state = match this.state.as_mut().project() { - StateProj::Extracting { future } => { + StateProj::Extracting { future, st } => { let (req, extracted) = ready!(future.as_mut().poll(cx)); match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); - let future = svc.call(req); + let future = svc.call(req, st); State::Call { future } } Err(err) => { let res = err.into_response(); - return Poll::Ready(Ok(res)); + return Poll::Ready(res); } } } StateProj::Call { future } => { - return future - .poll(cx) - .map(|result| result.map(IntoResponse::into_response)); + return future.poll(cx); } }; @@ -346,10 +316,7 @@ mod tests { async fn handler() {} let state = Secret("secret"); - let app = Router::new().route( - "/", - get(handler.layer(from_extractor_with_state::(state))), - ); + let app = Router::new().route("/", get(handler.layer(from_extractor()))); let client = TestClient::new(app); diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 22dab143..5e75bb53 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -7,9 +7,7 @@ mod from_fn; mod map_request; mod map_response; -pub use self::from_extractor::{ - from_extractor, from_extractor_with_state, FromExtractor, FromExtractorLayer, -}; +pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer}; pub use self::from_fn::{from_fn, from_fn_with_state, FromFn, FromFnLayer, Next}; pub use self::map_request::{ map_request, map_request_with_state, IntoMapRequestResult, MapRequest, MapRequestLayer, diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 4e2dae35..0c476181 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -919,10 +919,8 @@ where pub fn layer(self, layer: L) -> MethodRouter where L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, + L::Service: crate::Service + Clone + Send + 'static, + >::Future: Send + 'static, E: 'static, S: 'static, NewReqBody: HttpBody + 'static, diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index e7b717ea..8d489f56 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -314,10 +314,8 @@ where pub fn layer(self, layer: L) -> Router where L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, + L::Service: crate::Service + Clone + Send + 'static, + >::Future: Send + 'static, NewReqBody: HttpBody + 'static, { let routes = self @@ -688,10 +686,8 @@ where fn layer(self, layer: L) -> Endpoint where L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, + L::Service: crate::Service + Clone + Send + 'static, + >::Future: Send + 'static, NewReqBody: HttpBody + 'static, { match self { diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 25c8b859..b498ec87 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -51,10 +51,8 @@ impl Route { pub(crate) fn layer(self, layer: L) -> Route where L: Layer> + Clone + Send + 'static, - L::Service: Service> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Error: Into + 'static, - >>::Future: Send + 'static, + L::Service: crate::Service + Clone + Send + 'static, + >::Future: Send + 'static, NewReqBody: 'static, NewError: 'static, {