Don't internally Arc the state (#1460)

This commit is contained in:
David Pedersen 2022-10-09 22:55:28 +02:00 committed by GitHub
parent a2ab338e68
commit f9dc96fdce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 206 additions and 185 deletions

View file

@ -6,7 +6,7 @@ use axum::{
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, FutureExt, Map};
use std::{future::Future, marker::PhantomData, sync::Arc};
use std::{future::Future, marker::PhantomData};
mod or;
@ -24,11 +24,7 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the extracted inputs.
fn call(
self,
extractors: T,
state: Arc<S>,
) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
fn call(self, extractors: T, state: S) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, S, B> {
@ -133,7 +129,7 @@ macro_rules! impl_handler_call_with {
fn call(
self,
($($ty,)*): ($($ty,)*),
_state: Arc<S>,
_state: S,
) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
}
@ -178,7 +174,7 @@ where
{
type Future = BoxFuture<'static, Response>;
fn call(self, req: http::Request<B>, state: Arc<S>) -> Self::Future {
fn call(self, req: http::Request<B>, state: S) -> Self::Future {
Box::pin(async move {
match T::from_request(req, &state).await {
Ok(t) => self.handler.call(t, state).await,

View file

@ -7,7 +7,7 @@ use axum::{
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use std::{future::Future, marker::PhantomData, sync::Arc};
use std::{future::Future, marker::PhantomData};
/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
/// [`Handler`].
@ -37,7 +37,7 @@ where
fn call(
self,
extractors: Either<Lt, Rt>,
state: Arc<S>,
state: S,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors {
Either::E1(lt) => self
@ -68,7 +68,7 @@ where
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
fn call(self, req: Request<B>, state: S) -> Self::Future {
Box::pin(async move {
let (mut parts, body) = req.into_parts();

View file

@ -178,7 +178,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
impl<S, B> RouterExt<S, B> for Router<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self

View file

@ -53,7 +53,7 @@ where
impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
///

View file

@ -18,7 +18,7 @@ struct Extractor {
other: Query<HashMap<String, String>>,
}
#[derive(Default)]
#[derive(Default, Clone)]
struct AppState {
inner: InnerState,
}

View file

@ -1,4 +1,4 @@
use std::{convert::Infallible, sync::Arc};
use std::convert::Infallible;
use super::Handler;
use crate::routing::Route;
@ -7,7 +7,7 @@ pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B,
impl<S, B> BoxedHandler<S, B>
where
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
B: Send + 'static,
{
pub(crate) fn new<H, T>(handler: H) -> Self
@ -17,7 +17,7 @@ where
{
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state_arc(handler, state)),
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
}))
}
}
@ -38,7 +38,7 @@ impl<S, B, E> BoxedHandler<S, B, E> {
}))
}
pub(crate) fn into_route(self, state: Arc<S>) -> Route<B, E> {
pub(crate) fn into_route(self, state: S) -> Route<B, E> {
self.0.into_route(state)
}
}
@ -51,12 +51,13 @@ impl<S, B, E> Clone for BoxedHandler<S, B, E> {
trait ErasedHandler<S, B, E = Infallible>: Send {
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>;
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B, E>;
fn into_route(self: Box<Self>, state: S) -> Route<B, E>;
}
struct MakeErasedHandler<H, S, B> {
handler: H,
into_route: fn(H, Arc<S>) -> Route<B>,
into_route: fn(H, S) -> Route<B>,
}
impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B>
@ -69,7 +70,7 @@ where
Box::new(self.clone())
}
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B> {
fn into_route(self: Box<Self>, state: S) -> Route<B> {
(self.into_route)(self.handler, state)
}
}
@ -103,7 +104,7 @@ where
})
}
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B2, E2> {
fn into_route(self: Box<Self>, state: S) -> Route<B2, E2> {
(self.layer)(self.handler.into_route(state))
}
}

View file

@ -0,0 +1,84 @@
use super::Handler;
use crate::response::Response;
use http::Request;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
task::{Context, Poll},
};
use tower_service::Service;
pub(crate) struct IntoServiceStateInExtension<H, T, S, B> {
handler: H,
_marker: PhantomData<fn() -> (T, S, B)>,
}
#[test]
fn traits() {
use crate::test_helpers::*;
assert_send::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
assert_sync::<IntoServiceStateInExtension<(), NotSendSync, (), NotSendSync>>();
}
impl<H, T, S, B> IntoServiceStateInExtension<H, T, S, B> {
pub(crate) fn new(handler: H) -> Self {
Self {
handler,
_marker: PhantomData,
}
}
}
impl<H, T, S, B> fmt::Debug for IntoServiceStateInExtension<H, T, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoServiceStateInExtension")
.finish_non_exhaustive()
}
}
impl<H, T, S, B> Clone for IntoServiceStateInExtension<H, T, S, B>
where
H: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
_marker: PhantomData,
}
}
}
impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = super::future::IntoServiceFuture<H::Future>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or
// from `Layered` which bufferes in `<Layered as Handler>::call` and is therefore
// also always ready.
Poll::Ready(Ok(()))
}
fn call(&mut self, mut req: Request<B>) -> Self::Future {
use futures_util::future::FutureExt;
let state = req
.extensions_mut()
.remove::<S>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();
let future = Handler::call(handler, req, state);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)
}
}

View file

@ -44,7 +44,7 @@ use crate::{
routing::IntoMakeService,
};
use http::Request;
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
@ -101,7 +101,7 @@ pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request.
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future;
fn call(self, req: Request<B>, state: S) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler.
///
@ -152,11 +152,6 @@ pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
/// Convert the handler into a [`Service`] by providing the state
fn with_state(self, state: S) -> HandlerService<Self, T, S, B> {
self.with_state_arc(Arc::new(state))
}
/// Convert the handler into a [`Service`] by providing the state
fn with_state_arc(self, state: Arc<S>) -> HandlerService<Self, T, S, B> {
HandlerService::new(self, state)
}
}
@ -170,7 +165,7 @@ where
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _req: Request<B>, _state: Arc<S>) -> Self::Future {
fn call(self, _req: Request<B>, _state: S) -> Self::Future {
Box::pin(async move { self().await.into_response() })
}
}
@ -192,7 +187,7 @@ macro_rules! impl_handler {
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
fn call(self, req: Request<B>, state: S) -> Self::Future {
Box::pin(async move {
let (mut parts, body) = req.into_parts();
let state = &state;
@ -269,10 +264,10 @@ where
{
type Future = future::LayeredFuture<B, L::Service>;
fn call(self, req: Request<B>, state: Arc<S>) -> Self::Future {
fn call(self, req: Request<B>, state: S) -> Self::Future {
use futures_util::future::{FutureExt, Map};
let svc = self.handler.with_state_arc(state);
let svc = self.handler.with_state(state);
let svc = self.layer.layer(svc);
let future: Map<

View file

@ -8,20 +8,18 @@ use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
/// An adapter that makes a [`Handler`] into a [`Service`].
///
/// Created with [`Handler::with_state`], [`Handler::with_state_arc`] or
/// [`HandlerWithoutStateExt::into_service`].
/// Created with [`Handler::with_state`] or [`HandlerWithoutStateExt::into_service`].
///
/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service
pub struct HandlerService<H, T, S, B> {
handler: H,
state: Arc<S>,
state: S,
_marker: PhantomData<fn() -> (T, B)>,
}
@ -119,7 +117,7 @@ fn traits() {
}
impl<H, T, S, B> HandlerService<H, T, S, B> {
pub(super) fn new(handler: H, state: Arc<S>) -> Self {
pub(super) fn new(handler: H, state: S) -> Self {
Self {
handler,
state,
@ -137,11 +135,12 @@ impl<H, T, S, B> fmt::Debug for HandlerService<H, T, S, B> {
impl<H, T, S, B> Clone for HandlerService<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_marker: PhantomData,
}
}
@ -151,7 +150,7 @@ impl<H, T, S, B> Service<Request<B>> for HandlerService<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Send + Sync,
S: Clone + Send + Sync,
{
type Response = Response;
type Error = Infallible;
@ -169,7 +168,7 @@ where
use futures_util::future::FutureExt;
let handler = self.handler.clone();
let future = Handler::call(handler, req, Arc::clone(&self.state));
let future = Handler::call(handler, req, self.state.clone());
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)

View file

@ -10,7 +10,6 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
@ -99,13 +98,6 @@ pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
from_extractor_with_state_arc(Arc::new(state))
}
/// Create a middleware from an extractor with the given [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_extractor_with_state_arc<E, S>(state: Arc<S>) -> FromExtractorLayer<E, S> {
FromExtractorLayer {
state,
_marker: PhantomData,
@ -119,14 +111,17 @@ pub fn from_extractor_with_state_arc<E, S>(state: Arc<S>) -> FromExtractorLayer<
///
/// [`Layer`]: tower::Layer
pub struct FromExtractorLayer<E, S> {
state: Arc<S>,
state: S,
_marker: PhantomData<fn() -> E>,
}
impl<E, S> Clone for FromExtractorLayer<E, S> {
impl<E, S> Clone for FromExtractorLayer<E, S>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
state: self.state.clone(),
_marker: PhantomData,
}
}
@ -144,13 +139,16 @@ where
}
}
impl<E, T, S> Layer<T> for FromExtractorLayer<E, S> {
impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
where
S: Clone,
{
type Service = FromExtractor<T, E, S>;
fn layer(&self, inner: T) -> Self::Service {
FromExtractor {
inner,
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: PhantomData,
}
}
@ -161,7 +159,7 @@ impl<E, T, S> Layer<T> for FromExtractorLayer<E, S> {
/// See [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
inner: T,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> E>,
}
@ -175,11 +173,12 @@ fn traits() {
impl<T, E, S> Clone for FromExtractor<T, E, S>
where
T: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: PhantomData,
}
}
@ -205,7 +204,7 @@ where
B: Default + Send + 'static,
T: Service<Request<B>> + Clone,
T::Response: IntoResponse,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = T::Error;
@ -217,7 +216,7 @@ where
}
fn call(&mut self, req: Request<B>) -> Self::Future {
let state = Arc::clone(&self.state);
let state = self.state.clone();
let extract_future = Box::pin(async move {
let (mut parts, body) = req.into_parts();
let extracted = E::from_request_parts(&mut parts, &state).await;

View file

@ -9,7 +9,6 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{util::BoxCloneService, ServiceBuilder};
@ -140,15 +139,6 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
/// # let app: Router<_> = app;
/// ```
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
from_fn_with_state_arc(Arc::new(state), f)
}
/// Create a middleware from an async function with the given [`Arc`]'ed state.
///
/// See [`from_fn_with_state`] for an example.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
FromFnLayer {
f,
state,
@ -163,18 +153,19 @@ pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S,
/// Created with [`from_fn`]. See that function for more details.
pub struct FromFnLayer<F, S, T> {
f: F,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
impl<F, S, T> Clone for FromFnLayer<F, S, T>
where
F: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -183,13 +174,14 @@ where
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
where
F: Clone,
S: Clone,
{
type Service = FromFn<F, S, I, T>;
fn layer(&self, inner: I) -> Self::Service {
FromFn {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
inner,
_extractor: PhantomData,
}
@ -215,7 +207,7 @@ where
pub struct FromFn<F, S, I, T> {
f: F,
inner: I,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
@ -223,12 +215,13 @@ impl<F, S, I, T> Clone for FromFn<F, S, I, T>
where
F: Clone,
I: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -253,7 +246,7 @@ macro_rules! impl_service {
I::Response: IntoResponse,
I::Future: Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
@ -268,7 +261,7 @@ macro_rules! impl_service {
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let mut f = self.f.clone();
let state = Arc::clone(&self.state);
let state = self.state.clone();
let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();

View file

@ -9,7 +9,6 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
@ -112,7 +111,7 @@ use tower_service::Service;
/// # let _: Router = app;
/// ```
///
/// Note that to access state you must use either [`map_request_with_state`] or [`map_request_with_state_arc`].
/// Note that to access state you must use either [`map_request_with_state`].
pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
map_request_with_state((), f)
}
@ -155,16 +154,6 @@ pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
/// # let app: Router<_> = app;
/// ```
pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
map_request_with_state_arc(Arc::new(state), f)
}
/// Create a middleware from an async function that transforms a request, with the given [`Arc`]'ed
/// state.
///
/// See [`map_request_with_state`] for an example.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn map_request_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> MapRequestLayer<F, S, T> {
MapRequestLayer {
f,
state,
@ -177,18 +166,19 @@ pub fn map_request_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> MapRequestLay
/// Created with [`map_request`]. See that function for more details.
pub struct MapRequestLayer<F, S, T> {
f: F,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
impl<F, S, T> Clone for MapRequestLayer<F, S, T>
where
F: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -197,13 +187,14 @@ where
impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
where
F: Clone,
S: Clone,
{
type Service = MapRequest<F, S, I, T>;
fn layer(&self, inner: I) -> Self::Service {
MapRequest {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
inner,
_extractor: PhantomData,
}
@ -229,7 +220,7 @@ where
pub struct MapRequest<F, S, I, T> {
f: F,
inner: I,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
@ -237,12 +228,13 @@ impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
where
F: Clone,
I: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -267,7 +259,7 @@ macro_rules! impl_service {
I::Response: IntoResponse,
I::Future: Send + 'static,
B: Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
@ -282,7 +274,7 @@ macro_rules! impl_service {
let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let mut f = self.f.clone();
let state = Arc::clone(&self.state);
let state = self.state.clone();
let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();
@ -363,7 +355,7 @@ mod private {
}
/// Trait implemented by types that can be returned from [`map_request`],
/// [`map_request_with_state`], and [`map_request_with_state_arc`].
/// [`map_request_with_state`].
///
/// This trait is sealed such that it cannot be implemented outside this crate.
pub trait IntoMapRequestResult<B>: private::Sealed<B> {

View file

@ -9,7 +9,6 @@ use std::{
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
@ -67,7 +66,7 @@ use tower_service::Service;
/// # let _: Router = app;
/// ```
///
/// Note that to access state you must use either [`map_response_with_state`] or [`map_response_with_state_arc`].
/// Note that to access state you must use either [`map_response_with_state`].
pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
map_response_with_state((), f)
}
@ -110,16 +109,6 @@ pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
/// # let app: Router<_> = app;
/// ```
pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
map_response_with_state_arc(Arc::new(state), f)
}
/// Create a middleware from an async function that transforms a response, with the given [`Arc`]'ed
/// state.
///
/// See [`map_response_with_state`] for an example.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn map_response_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> MapResponseLayer<F, S, T> {
MapResponseLayer {
f,
state,
@ -132,18 +121,19 @@ pub fn map_response_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> MapResponseL
/// Created with [`map_response`]. See that function for more details.
pub struct MapResponseLayer<F, S, T> {
f: F,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
impl<F, S, T> Clone for MapResponseLayer<F, S, T>
where
F: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -152,13 +142,14 @@ where
impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
where
F: Clone,
S: Clone,
{
type Service = MapResponse<F, S, I, T>;
fn layer(&self, inner: I) -> Self::Service {
MapResponse {
f: self.f.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
inner,
_extractor: PhantomData,
}
@ -184,7 +175,7 @@ where
pub struct MapResponse<F, S, I, T> {
f: F,
inner: I,
state: Arc<S>,
state: S,
_extractor: PhantomData<fn() -> T>,
}
@ -192,12 +183,13 @@ impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
where
F: Clone,
I: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
state: self.state.clone(),
_extractor: self._extractor,
}
}
@ -221,7 +213,7 @@ macro_rules! impl_service {
I::Future: Send + 'static,
B: Send + 'static,
ResBody: Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
@ -237,7 +229,7 @@ macro_rules! impl_service {
let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let mut f = self.f.clone();
let _state = Arc::clone(&self.state);
let _state = self.state.clone();
let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();

View file

@ -8,19 +8,14 @@ mod map_request;
mod map_response;
pub use self::from_extractor::{
from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor,
FromExtractorLayer,
};
pub use self::from_fn::{
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
from_extractor, from_extractor_with_state, FromExtractor, FromExtractorLayer,
};
pub use self::from_fn::{from_fn, from_fn_with_state, FromFn, FromFnLayer, Next};
pub use self::map_request::{
map_request, map_request_with_state, map_request_with_state_arc, IntoMapRequestResult,
MapRequest, MapRequestLayer,
map_request, map_request_with_state, IntoMapRequestResult, MapRequest, MapRequestLayer,
};
pub use self::map_response::{
map_response, map_response_with_state, map_response_with_state_arc, MapResponse,
MapResponseLayer,
map_response, map_response_with_state, MapResponse, MapResponseLayer,
};
pub use crate::extension::AddExtension;

View file

@ -17,7 +17,6 @@ use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
sync::Arc,
task::{Context, Poll},
};
use tower::{service_fn, util::MapResponseLayer};
@ -85,6 +84,7 @@ macro_rules! top_level_service_fn {
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
S: Clone,
{
on_service(MethodFilter::$method, svc)
}
@ -145,7 +145,7 @@ macro_rules! top_level_handler_fn {
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
on(MethodFilter::$method, handler)
}
@ -328,6 +328,7 @@ where
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
S: Clone,
{
MethodRouter::new().on_service(filter, svc)
}
@ -391,6 +392,7 @@ where
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
B: Send + 'static,
S: Clone,
{
MethodRouter::new()
.fallback_service(svc)
@ -430,7 +432,7 @@ where
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
MethodRouter::new().on(filter, handler)
}
@ -477,7 +479,7 @@ where
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
MethodRouter::new().fallback(handler).skip_allow_header()
}
@ -570,6 +572,7 @@ impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
impl<S, B> MethodRouter<S, B, Infallible>
where
B: Send + 'static,
S: Clone,
{
/// Chain an additional handler that will accept requests matching the given
/// `MethodFilter`.
@ -705,6 +708,7 @@ where
impl<S, B, E> MethodRouter<S, B, E>
where
B: Send + 'static,
S: Clone,
{
/// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
/// requests.
@ -731,13 +735,6 @@ where
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state(self, state: S) -> WithState<B, E> {
self.with_state_arc(Arc::new(state))
}
/// Provide the [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state_arc(self, state: Arc<S>) -> WithState<B, E> {
WithState {
get: self.get.into_route(&state),
head: self.head.into_route(&state),
@ -752,7 +749,7 @@ where
}
}
pub(crate) fn map_state<S2>(self, state: &Arc<S>) -> MethodRouter<S2, B, E>
pub(crate) fn map_state<S2>(self, state: &S) -> MethodRouter<S2, B, E>
where
E: 'static,
S: 'static,
@ -841,6 +838,7 @@ where
methods: &[&'static str],
) where
MethodEndpoint<S, B, E>: Clone,
S: Clone,
{
if endpoint_filter.contains(filter) {
if out.is_some() {
@ -1174,6 +1172,7 @@ impl<S, B, E> Clone for MethodRouter<S, B, E> {
impl<S, B, E> Default for MethodRouter<S, B, E>
where
B: Send + 'static,
S: Clone,
{
fn default() -> Self {
Self::new()
@ -1186,7 +1185,10 @@ enum MethodEndpoint<S, B, E> {
BoxedHandler(BoxedHandler<S, B, E>),
}
impl<S, B, E> MethodEndpoint<S, B, E> {
impl<S, B, E> MethodEndpoint<S, B, E>
where
S: Clone,
{
fn is_some(&self) -> bool {
matches!(self, Self::Route(_) | Self::BoxedHandler(_))
}
@ -1211,7 +1213,7 @@ impl<S, B, E> MethodEndpoint<S, B, E> {
}
}
fn map_state<S2>(self, state: &Arc<S>) -> MethodEndpoint<S2, B, E> {
fn map_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E> {
match self {
Self::None => MethodEndpoint::None,
Self::Route(route) => MethodEndpoint::Route(route),
@ -1237,7 +1239,7 @@ impl<S, B, E> MethodEndpoint<S, B, E> {
}
}
fn into_route(self, state: &Arc<S>) -> Option<Route<B, E>> {
fn into_route(self, state: &S) -> Option<Route<B, E>> {
match self {
Self::None => None,
Self::Route(route) => Some(route),

View file

@ -65,13 +65,16 @@ impl RouteId {
/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
state: Option<Arc<S>>,
state: Option<S>,
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<S, B>,
}
impl<S, B> Clone for Router<S, B> {
impl<S, B> Clone for Router<S, B>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
@ -85,7 +88,7 @@ impl<S, B> Clone for Router<S, B> {
impl<S, B> Default for Router<S, B>
where
B: HttpBody + Send + 'static,
S: Default + Send + Sync + 'static,
S: Default + Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::with_state(S::default())
@ -125,7 +128,7 @@ where
impl<S, B> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
/// Create a new `Router` with the given state.
///
@ -134,39 +137,6 @@ where
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn with_state(state: S) -> Self {
Self::with_state_arc(Arc::new(state))
}
/// Create a new `Router` with the given [`Arc`]'ed state.
///
/// See [`State`] for more details about accessing state.
///
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
///
/// Note that the state type you extract with [`State`] must implement [`FromRef<S>`]. If
/// you're extracting `S` itself that requires `S` to implement `Clone`. That is still the
/// case, even if you're using this method:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use std::sync::Arc;
///
/// // `AppState` must implement `Clone` to be extracted...
/// #[derive(Clone)]
/// struct AppState {}
///
/// // ...even though we're wrapping it an an `Arc`
/// let state = Arc::new(AppState {});
///
/// let app: Router<AppState> = Router::with_state_arc(state).route("/", get(handler));
///
/// async fn handler(state: State<AppState>) {}
/// ```
///
/// [`FromRef<S>`]: crate::extract::FromRef
/// [`State`]: crate::extract::State
pub fn with_state_arc(state: Arc<S>) -> Self {
Self {
state: Some(state),
routes: Default::default(),
@ -272,11 +242,11 @@ where
#[track_caller]
pub fn nest<S2>(self, path: &str, mut router: Router<S2, B>) -> Self
where
S2: Send + Sync + 'static,
S2: Clone + Send + Sync + 'static,
{
if router.state.is_none() {
let s = self.state.clone();
router.state = match try_downcast::<Option<Arc<S2>>, Option<Arc<S>>>(s) {
router.state = match try_downcast::<Option<S2>, Option<S>>(s) {
Ok(state) => state,
Err(_) => panic!(
"can't nest a `Router` that wants to inherit state of type `{}` \
@ -335,7 +305,7 @@ where
pub fn merge<S2, R>(mut self, other: R) -> Self
where
R: Into<Router<S2, B>>,
S2: Send + Sync + 'static,
S2: Clone + Send + Sync + 'static,
{
let Router {
state,
@ -373,7 +343,7 @@ where
where
B: Send + 'static,
S: 'static,
S2: 'static,
S2: Clone + 'static,
{
r.downcast_state().unwrap()
}
@ -600,8 +570,11 @@ enum Fallback<S, B, E = Infallible> {
BoxedHandler(BoxedHandler<S, B, E>),
}
impl<S, B, E> Fallback<S, B, E> {
fn map_state<S2>(self, state: &Arc<S>) -> Fallback<S2, B, E> {
impl<S, B, E> Fallback<S, B, E>
where
S: Clone,
{
fn map_state<S2>(self, state: &S) -> Fallback<S2, B, E> {
match self {
Self::Default(route) => Fallback::Default(route),
Self::Service(route) => Fallback::Service(route),
@ -635,7 +608,7 @@ impl<S, B, E> Fallback<S, B, E> {
}
}
fn into_route(self, state: &Arc<S>) -> Route<B, E> {
fn into_route(self, state: &S) -> Route<B, E> {
match self {
Self::Default(route) => route,
Self::Service(route) => route,

View file

@ -33,7 +33,7 @@ where
#[track_caller]
pub(super) fn new<S>(router: Router<S, B>) -> Self
where
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
let state = router
.state
@ -45,7 +45,7 @@ where
.map(|(route_id, endpoint)| {
let route = match endpoint {
Endpoint::MethodRouter(method_router) => {
Route::new(method_router.with_state_arc(Arc::clone(&state)))
Route::new(method_router.with_state(state.clone()))
}
Endpoint::Route(route) => route,
};

View file

@ -17,7 +17,7 @@ pub(crate) struct TestClient {
impl TestClient {
pub(crate) fn new<S>(router: Router<S, Body>) -> Self
where
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
Self::from_service(router.into_service())
}