Always store state in an Arc (#1270)

* Add extension and state benchmarks

* wip

* Arc the state everywhere

* don't require `S: Clone`

* fix example
This commit is contained in:
David Pedersen 2022-08-17 22:08:24 +02:00 committed by GitHub
parent 423308de3c
commit e7f1c88cd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 199 additions and 141 deletions

View file

@ -8,7 +8,7 @@ use self::rejection::*;
use crate::response::IntoResponse;
use async_trait::async_trait;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible;
use std::{convert::Infallible, sync::Arc};
pub mod rejection;
@ -49,7 +49,7 @@ pub use self::from_ref::FromRef;
/// where
/// // these bounds are required by `async_trait`
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = http::StatusCode;
///
@ -79,7 +79,7 @@ pub trait FromRequest<S, B>: Sized {
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<S, B> {
state: S,
pub(crate) state: Arc<S>,
method: Method,
uri: Uri,
version: Version,
@ -110,6 +110,17 @@ impl<S, B> RequestParts<S, B> {
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
Self::with_state_arc(Arc::new(state), req)
}
/// Create a new `RequestParts` with the given [`Arc`]'ed state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state_arc(state: Arc<S>, req: Request<B>) -> Self {
let (
http::request::Parts {
method,
@ -153,7 +164,7 @@ impl<S, B> RequestParts<S, B> {
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Infallible;
///
@ -285,7 +296,7 @@ impl<S, T, B> FromRequest<S, B> for Option<T>
where
T: FromRequest<S, B>,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -299,7 +310,7 @@ impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where
T: FromRequest<S, B>,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;

View file

@ -3,13 +3,13 @@ use crate::BoxError;
use async_trait::async_trait;
use bytes::Bytes;
use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible;
use std::{convert::Infallible, sync::Arc};
#[async_trait]
impl<S, B> FromRequest<S, B> for Request<B>
where
B: Send,
S: Clone + Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
@ -17,7 +17,7 @@ where
let req = std::mem::replace(
req,
RequestParts {
state: req.state().clone(),
state: Arc::clone(&req.state),
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
@ -35,7 +35,7 @@ where
impl<S, B> FromRequest<S, B> for Method
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -48,7 +48,7 @@ where
impl<S, B> FromRequest<S, B> for Uri
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -61,7 +61,7 @@ where
impl<S, B> FromRequest<S, B> for Version
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -79,7 +79,7 @@ where
impl<S, B> FromRequest<S, B> for HeaderMap
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -94,7 +94,7 @@ where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = BytesRejection;
@ -115,7 +115,7 @@ where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = StringRejection;
@ -137,7 +137,7 @@ where
impl<S, B> FromRequest<S, B> for http::request::Parts
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;

View file

@ -7,7 +7,7 @@ use std::convert::Infallible;
impl<S, B> FromRequest<S, B> for ()
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -26,7 +26,7 @@ macro_rules! impl_from_request {
where
$( $ty: FromRequest<S, B> + Send, )*
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Response;

View file

@ -195,7 +195,7 @@ macro_rules! impl_traits_for_either {
$($ident: FromRequest<S, B>),*,
$last: FromRequest<S, B>,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = $last::Rejection;

View file

@ -33,7 +33,7 @@ use std::ops::{Deref, DerefMut};
/// impl<S, B> FromRequest<S, B> for Session
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = (StatusCode, String);
///
@ -49,7 +49,7 @@ use std::ops::{Deref, DerefMut};
/// impl<S, B> FromRequest<S, B> for CurrentUser
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Response;
///
@ -93,7 +93,7 @@ struct CachedEntry<T>(T);
impl<S, B, T> FromRequest<S, B> for Cached<T>
where
B: Send,
S: Send,
S: Send + Sync,
T: FromRequest<S, B> + Clone + Send + Sync + 'static,
{
type Rejection = T::Rejection;
@ -145,7 +145,7 @@ mod tests {
impl<S, B> FromRequest<S, B> for Extractor
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;

View file

@ -91,7 +91,7 @@ pub struct CookieJar {
impl<S, B> FromRequest<S, B> for CookieJar
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;

View file

@ -90,7 +90,7 @@ impl<K> fmt::Debug for PrivateCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for PrivateCookieJar<K>
where
B: Send,
S: Send,
S: Send + Sync,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

View file

@ -108,7 +108,7 @@ impl<K> fmt::Debug for SignedCookieJar<K> {
impl<S, B, K> FromRequest<S, B> for SignedCookieJar<K>
where
B: Send,
S: Send,
S: Send + Sync,
K: FromRef<S> + Into<Key>,
{
type Rejection = Infallible;

View file

@ -61,7 +61,7 @@ where
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = FormRejection;

View file

@ -62,7 +62,7 @@ impl<T, S, B> FromRequest<S, B> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = QueryRejection;

View file

@ -110,7 +110,7 @@ impl<E, R> DerefMut for WithRejection<E, R> {
impl<B, E, R, S> FromRequest<S, B> for WithRejection<E, R>
where
B: Send,
S: Send,
S: Send + Sync,
E: FromRequest<S, B>,
R: From<E::Rejection> + IntoResponse,
{
@ -138,7 +138,7 @@ mod tests {
impl<S, B> FromRequest<S, B> for TestExtractor
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ();

View file

@ -6,7 +6,7 @@ use axum::{
response::{IntoResponse, Response},
};
use futures_util::future::{BoxFuture, FutureExt, Map};
use std::{future::Future, marker::PhantomData};
use std::{future::Future, marker::PhantomData, sync::Arc};
mod or;
@ -24,7 +24,11 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the extracted inputs.
fn call(self, state: S, extractors: T) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
fn call(
self,
state: Arc<S>,
extractors: T,
) -> <Self as HandlerCallWithExtractors<T, S, B>>::Future;
/// Conver this `HandlerCallWithExtractors` into [`Handler`].
fn into_handler(self) -> IntoHandler<Self, T, S, B> {
@ -70,7 +74,7 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// impl<S, B> FromRequest<S, B> for AdminPermissions
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// // check for admin permissions...
/// # type Rejection = ();
@ -85,7 +89,7 @@ pub trait HandlerCallWithExtractors<T, S, B>: Sized {
/// impl<S, B> FromRequest<S, B> for User
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// // check for a logged in user...
/// # type Rejection = ();
@ -130,7 +134,7 @@ macro_rules! impl_handler_call_with {
fn call(
self,
_state: S,
_state: Arc<S>,
($($ty,)*): ($($ty,)*),
) -> <Self as HandlerCallWithExtractors<($($ty,)*), S, B>>::Future {
self($($ty,)*).map(IntoResponse::into_response)
@ -172,13 +176,13 @@ where
T: FromRequest<S, B> + Send + 'static,
T::Rejection: Send,
B: Send + 'static,
S: Clone + Send + 'static,
S: Send + Sync + 'static,
{
type Future = BoxFuture<'static, Response>;
fn call(self, state: S, req: http::Request<B>) -> Self::Future {
fn call(self, state: Arc<S>, req: http::Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state(state.clone(), req);
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req);
match req.extract::<T>().await {
Ok(t) => self.handler.call(state, t).await,
Err(rejection) => rejection.into_response(),

View file

@ -8,7 +8,7 @@ use axum::{
};
use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
use http::StatusCode;
use std::{future::Future, marker::PhantomData};
use std::{future::Future, marker::PhantomData, sync::Arc};
/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
/// [`Handler`].
@ -37,7 +37,7 @@ where
fn call(
self,
state: S,
state: Arc<S>,
extractors: Either<Lt, Rt>,
) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S, B>>::Future {
match extractors {
@ -64,14 +64,14 @@ where
Lt::Rejection: Send,
Rt::Rejection: Send,
B: Send + 'static,
S: Clone + Send + 'static,
S: Send + Sync + 'static,
{
// this puts `futures_util` in our public API but thats fine in axum-extra
type Future = BoxFuture<'static, Response>;
fn call(self, state: S, req: Request<B>) -> Self::Future {
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state(state.clone(), req);
let mut req = RequestParts::with_state_arc(Arc::clone(&state), req);
if let Ok(lt) = req.extract::<Lt>().await {
return self.lhs.call(state, lt).await;

View file

@ -104,7 +104,7 @@ where
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
T: DeserializeOwned,
S: Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;

View file

@ -103,7 +103,7 @@ where
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = ProtoBufRejection;

View file

@ -178,7 +178,7 @@ pub trait RouterExt<S, B>: sealed::Sealed {
impl<S, B> RouterExt<S, B> for Router<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
#[cfg(feature = "typed-routing")]
fn typed_get<H, T, P>(self, handler: H) -> Self

View file

@ -53,7 +53,7 @@ where
impl<S, B> Resource<S, B>
where
B: axum::body::HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
/// Create a `Resource` with the given name and state.
///

View file

@ -223,7 +223,7 @@ fn impl_struct_by_extracting_each_field(
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: Send,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #rejection_ident;
@ -659,7 +659,7 @@ fn impl_struct_by_extracting_all_at_once(
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
#rejection_bound
B: ::std::marker::Send,
S: ::std::marker::Send,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;
@ -725,7 +725,7 @@ fn impl_enum_by_extracting_all_at_once(
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
B::Data: ::std::marker::Send,
B::Error: ::std::convert::Into<::axum::BoxError>,
S: ::std::marker::Send,
S: ::std::marker::Send + ::std::marker::Sync,
{
type Rejection = #associated_rejection_type;

View file

@ -226,7 +226,7 @@ mod typed_path;
/// impl<S, B> FromRequest<S, B> for OtherExtractor
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// // this rejection doesn't implement `Display` and `Error`
/// type Rejection = (StatusCode, String);

View file

@ -130,7 +130,7 @@ fn expand_named_fields(
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;
@ -233,7 +233,7 @@ fn expand_unnamed_fields(
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;
@ -315,7 +315,7 @@ fn expand_unit_fields(
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = #rejection_assoc_type;

View file

@ -10,7 +10,7 @@ struct A;
impl<S, B> FromRequest<S, B> for A
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ();

View file

@ -10,7 +10,7 @@ struct A;
impl<S, B> FromRequest<S, B> for A
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ();

View file

@ -123,7 +123,7 @@ impl A {
impl<S, B> FromRequest<S, B> for A
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ();

View file

@ -10,7 +10,7 @@ struct A;
impl<S, B> FromRequest<S, B> for A
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ();

View file

@ -17,7 +17,7 @@ struct OtherExtractor;
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = OtherExtractorRejection;

View file

@ -31,7 +31,7 @@ struct OtherExtractor;
impl<S, B> FromRequest<S, B> for OtherExtractor
where
B: Send + 'static,
S: Send,
S: Send + Sync,
{
// this rejection doesn't implement `Display` and `Error`
type Rejection = (StatusCode, String);

View file

@ -1,6 +1,7 @@
use axum::{
extract::State,
routing::{get, post},
Json, Router, Server,
Extension, Json, Router, Server,
};
use hyper::server::conn::AddrIncoming;
use serde::{Deserialize, Serialize};
@ -50,6 +51,30 @@ fn main() {
}),
)
});
let state = AppState {
_string: "aaaaaaaaaaaaaaaaaa".to_owned(),
_vec: Vec::from([
"aaaaaaaaaaaaaaaaaa".to_owned(),
"bbbbbbbbbbbbbbbbbb".to_owned(),
"cccccccccccccccccc".to_owned(),
]),
};
benchmark("extension").run(|| {
Router::new()
.route("/", get(|_: Extension<AppState>| async {}))
.layer(Extension(state.clone()))
});
benchmark("state")
.run(|| Router::with_state(state.clone()).route("/", get(|_: State<AppState>| async {})));
}
#[derive(Clone)]
struct AppState {
_string: String,
_vec: Vec<String>,
}
#[derive(Deserialize, Serialize)]
@ -92,9 +117,10 @@ impl BenchmarkBuilder {
config_method!(headers, &'static [(&'static str, &'static str)]);
config_method!(body, &'static str);
fn run<F>(self, f: F)
fn run<F, S>(self, f: F)
where
F: FnOnce() -> Router,
F: FnOnce() -> Router<S>,
S: Clone + Send + Sync + 'static,
{
// support only running some benchmarks with
// ```

View file

@ -424,7 +424,7 @@ struct ExtractUserAgent(HeaderValue);
impl<S, B> FromRequest<S, B> for ExtractUserAgent
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
@ -476,7 +476,7 @@ struct AuthenticatedUser {
impl<S, B> FromRequest<S, B> for AuthenticatedUser
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Response;

View file

@ -77,7 +77,7 @@ impl<T, S, B> FromRequest<S, B> for Extension<T>
where
T: Clone + Send + Sync + 'static,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ExtensionRejection;

View file

@ -131,7 +131,7 @@ pub struct ConnectInfo<T>(pub T);
impl<S, B, T> FromRequest<S, B> for ConnectInfo<T>
where
B: Send,
S: Send,
S: Send + Sync,
T: Clone + Send + Sync + 'static,
{
type Rejection = <Extension<Self> as FromRequest<S, B>>::Rejection;

View file

@ -41,7 +41,7 @@ where
T: FromRequest<S, B>,
T::Rejection: IntoResponse,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = ContentLengthLimitRejection<T::Rejection>;

View file

@ -24,7 +24,7 @@ pub struct Host(pub String);
impl<S, B> FromRequest<S, B> for Host
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = HostRejection;

View file

@ -67,7 +67,7 @@ impl MatchedPath {
impl<S, B> FromRequest<S, B> for MatchedPath
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = MatchedPathRejection;

View file

@ -54,7 +54,7 @@ impl<S, B> FromRequest<S, B> for Multipart
where
B: HttpBody<Data = Bytes> + Default + Unpin + Send + 'static,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = MultipartRejection;

View file

@ -167,7 +167,7 @@ impl<T, S, B> FromRequest<S, B> for Path<T>
where
T: DeserializeOwned + Send,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = PathRejection;

View file

@ -53,7 +53,7 @@ impl<T, S, B> FromRequest<S, B> for Query<T>
where
T: DeserializeOwned,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = QueryRejection;

View file

@ -30,7 +30,7 @@ pub struct RawQuery(pub Option<String>);
impl<S, B> FromRequest<S, B> for RawQuery
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;

View file

@ -89,7 +89,7 @@ pub struct OriginalUri(pub Uri);
impl<S, B> FromRequest<S, B> for OriginalUri
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Infallible;
@ -146,7 +146,7 @@ where
B: HttpBody + Send + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;
@ -201,7 +201,7 @@ pub struct RawBody<B = Body>(pub B);
impl<S, B> FromRequest<S, B> for RawBody<B>
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = BodyAlreadyExtracted;

View file

@ -153,7 +153,7 @@ use std::{
/// // keep `S` generic but require that it can produce a `MyLibraryState`
/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
/// MyLibraryState: FromRef<S>,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = Infallible;
///
@ -182,7 +182,7 @@ impl<B, OuterState, InnerState> FromRequest<OuterState, B> for State<InnerState>
where
B: Send,
InnerState: FromRef<OuterState>,
OuterState: Send,
OuterState: Send + Sync,
{
type Rejection = Infallible;

View file

@ -278,7 +278,7 @@ impl WebSocketUpgrade {
impl<S, B> FromRequest<S, B> for WebSocketUpgrade
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = WebSocketUpgradeRejection;

View file

@ -62,7 +62,7 @@ where
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = FormRejection;

View file

@ -5,6 +5,7 @@ use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
@ -16,7 +17,7 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service
pub struct IntoService<H, T, S, B> {
handler: H,
state: S,
state: Arc<S>,
_marker: PhantomData<fn() -> (T, B)>,
}
@ -35,7 +36,7 @@ fn traits() {
}
impl<H, T, S, B> IntoService<H, T, S, B> {
pub(super) fn new(handler: H, state: S) -> Self {
pub(super) fn new(handler: H, state: Arc<S>) -> Self {
Self {
handler,
state,
@ -55,12 +56,11 @@ impl<H, T, S, B> fmt::Debug for IntoService<H, T, S, B> {
impl<H, T, S, B> Clone for IntoService<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
state: self.state.clone(),
state: Arc::clone(&self.state),
_marker: PhantomData,
}
}
@ -70,7 +70,7 @@ impl<H, T, S, B> Service<Request<B>> for IntoService<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone,
S: Send + Sync,
{
type Response = Response;
type Error = Infallible;
@ -88,7 +88,7 @@ where
use futures_util::future::FutureExt;
let handler = self.handler.clone();
let future = Handler::call(handler, self.state.clone(), req);
let future = Handler::call(handler, Arc::clone(&self.state), req);
let future = future.map(Ok as _);
super::future::IntoServiceFuture::new(future)

View file

@ -5,6 +5,7 @@ use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower_service::Service;
@ -54,7 +55,7 @@ impl<H, T, S, B> Service<Request<B>> for IntoServiceStateInExtension<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
@ -73,7 +74,7 @@ where
let state = req
.extensions_mut()
.remove::<S>()
.remove::<Arc<S>>()
.expect("state extension missing. This is a bug in axum, please file an issue");
let handler = self.handler.clone();

View file

@ -42,7 +42,7 @@ use crate::{
routing::IntoMakeService,
};
use http::Request;
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin};
use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;
@ -100,7 +100,7 @@ pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
type Future: Future<Output = Response> + Send + 'static;
/// Call the handler with the given request.
fn call(self, state: S, req: Request<B>) -> Self::Future;
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future;
/// Apply a [`tower::Layer`] to the handler.
///
@ -151,6 +151,11 @@ pub trait Handler<T, S = (), B = Body>: Clone + Send + Sized + 'static {
/// Convert the handler into a [`Service`] by providing the state
fn with_state(self, state: S) -> WithState<Self, T, S, B> {
self.with_state_arc(Arc::new(state))
}
/// Convert the handler into a [`Service`] by providing the state
fn with_state_arc(self, state: Arc<S>) -> WithState<Self, T, S, B> {
WithState {
service: IntoService::new(self, state),
}
@ -166,7 +171,7 @@ where
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, _state: S, _req: Request<B>) -> Self::Future {
fn call(self, _state: Arc<S>, _req: Request<B>) -> Self::Future {
Box::pin(async move { self().await.into_response() })
}
}
@ -179,15 +184,15 @@ macro_rules! impl_handler {
F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
B: Send + 'static,
S: Send + 'static,
S: Send + Sync + 'static,
Res: IntoResponse,
$( $ty: FromRequest<S, B> + Send,)*
{
type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
fn call(self, state: S, req: Request<B>) -> Self::Future {
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
Box::pin(async move {
let mut req = RequestParts::with_state(state, req);
let mut req = RequestParts::with_state_arc(state, req);
$(
let $ty = match $ty::from_request(&mut req).await {
@ -254,10 +259,10 @@ where
{
type Future = future::LayeredFuture<B, L::Service>;
fn call(self, state: S, req: Request<B>) -> Self::Future {
fn call(self, state: Arc<S>, req: Request<B>) -> Self::Future {
use futures_util::future::{FutureExt, Map};
let svc = self.handler.with_state(state);
let svc = self.handler.with_state_arc(state);
let svc = self.layer.layer(svc);
let future: Map<

View file

@ -106,7 +106,7 @@ impl<H, T, S, B> Service<Request<B>> for WithState<H, T, S, B>
where
H: Handler<T, S, B> + Clone + Send + 'static,
B: Send + 'static,
S: Clone,
S: Send + Sync,
{
type Response = <IntoService<H, T, S, B> as Service<Request<B>>>::Response;
type Error = <IntoService<H, T, S, B> as Service<Request<B>>>::Error;
@ -134,7 +134,6 @@ impl<H, T, S, B> std::fmt::Debug for WithState<H, T, S, B> {
impl<H, T, S, B> Clone for WithState<H, T, S, B>
where
H: Clone,
S: Clone,
{
fn clone(&self) -> Self {
Self {

View file

@ -100,7 +100,7 @@ where
B: HttpBody + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
S: Send + Sync,
{
type Rejection = JsonRejection;

View file

@ -48,7 +48,7 @@ use tower_service::Service;
/// impl<S, B> FromRequest<S, B> for RequireAuth
/// where
/// B: Send,
/// S: Send,
/// S: Send + Sync,
/// {
/// type Rejection = StatusCode;
///
@ -283,7 +283,7 @@ mod tests {
impl<S, B> FromRequest<S, B> for RequireAuth
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = StatusCode;

View file

@ -16,6 +16,7 @@ use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
use tower::{service_fn, util::MapResponseLayer};
@ -143,7 +144,7 @@ macro_rules! top_level_handler_fn {
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
on(MethodFilter::$method, handler)
}
@ -279,7 +280,7 @@ macro_rules! chained_handler_fn {
where
H: Handler<T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
self.on(MethodFilter::$method, handler)
}
@ -428,7 +429,7 @@ where
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
MethodRouter::new().on(filter, handler)
}
@ -475,7 +476,7 @@ where
H: Handler<T, S, B>,
B: Send + 'static,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
MethodRouter::new()
.fallback_boxed_response_body(IntoServiceStateInExtension::new(handler))
@ -599,7 +600,7 @@ where
where
H: Handler<T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
self.on_service_boxed_response_body(filter, IntoServiceStateInExtension::new(handler))
}
@ -618,7 +619,7 @@ where
where
H: Handler<T, S, B>,
T: 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
self.fallback_service(IntoServiceStateInExtension::new(handler))
}
@ -727,6 +728,13 @@ where
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state(self, state: S) -> WithState<S, B, E> {
self.with_state_arc(Arc::new(state))
}
/// Provide the [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn with_state_arc(self, state: Arc<S>) -> WithState<S, B, E> {
WithState {
method_router: self,
state,
@ -1127,7 +1135,7 @@ where
/// Created with [`MethodRouter::with_state`]
pub struct WithState<S, B, E> {
method_router: MethodRouter<S, B, E>,
state: S,
state: Arc<S>,
}
impl<S, B, E> WithState<S, B, E> {
@ -1156,14 +1164,11 @@ impl<S, B, E> WithState<S, B, E> {
}
}
impl<S, B, E> Clone for WithState<S, B, E>
where
S: Clone,
{
impl<S, B, E> Clone for WithState<S, B, E> {
fn clone(&self) -> Self {
Self {
method_router: self.method_router.clone(),
state: self.state.clone(),
state: Arc::clone(&self.state),
}
}
}
@ -1183,7 +1188,7 @@ where
impl<S, B, E> Service<Request<B>> for WithState<S, B, E>
where
B: HttpBody,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = E;
@ -1232,7 +1237,7 @@ where
},
} = self;
req.extensions_mut().insert(state.clone());
req.extensions_mut().insert(Arc::clone(state));
call!(req, method, HEAD, head);
call!(req, method, HEAD, get);

View file

@ -62,19 +62,16 @@ impl RouteId {
/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
state: S,
state: Arc<S>,
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<B>,
}
impl<S, B> Clone for Router<S, B>
where
S: Clone,
{
impl<S, B> Clone for Router<S, B> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
state: Arc::clone(&self.state),
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
@ -85,7 +82,7 @@ where
impl<S, B> Default for Router<S, B>
where
B: HttpBody + Send + 'static,
S: Default + Clone + Send + Sync + 'static,
S: Default + Send + Sync + 'static,
{
fn default() -> Self {
Self::with_state(S::default())
@ -125,7 +122,7 @@ where
impl<S, B> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
/// Create a new `Router` with the given state.
///
@ -134,6 +131,16 @@ where
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn with_state(state: S) -> Self {
Self::with_state_arc(Arc::new(state))
}
/// Create a new `Router` with the given [`Arc`]'ed state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
///
/// Unless you add additional routes this will respond with `404 Not Found` to
/// all requests.
pub fn with_state_arc(state: Arc<S>) -> Self {
Self {
state,
routes: Default::default(),
@ -262,7 +269,7 @@ where
pub fn merge<S2, R>(mut self, other: R) -> Self
where
R: Into<Router<S2, B>>,
S2: Clone + Send + Sync + 'static,
S2: Send + Sync + 'static,
{
let Router {
state,
@ -282,7 +289,7 @@ where
method_router
// this will set the state for each route
// such we don't override the inner state later in `MethodRouterWithState`
.layer(Extension(state.clone()))
.layer(Extension(Arc::clone(&state)))
.downcast_state(),
),
Endpoint::Route(route) => self.route_service(path, route),
@ -383,8 +390,8 @@ where
H: Handler<T, S, B>,
T: 'static,
{
let state = self.state.clone();
self.fallback_service(handler.with_state(state))
let state = Arc::clone(&self.state);
self.fallback_service(handler.with_state_arc(state))
}
/// Add a fallback [`Service`] to the router.
@ -484,7 +491,10 @@ where
.clone();
match &mut route {
Endpoint::MethodRouter(inner) => inner.clone().with_state(self.state.clone()).call(req),
Endpoint::MethodRouter(inner) => inner
.clone()
.with_state_arc(Arc::clone(&self.state))
.call(req),
Endpoint::Route(inner) => inner.call(req),
}
}
@ -498,7 +508,7 @@ where
impl<S, B> Service<Request<B>> for Router<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
@ -618,10 +628,7 @@ enum Endpoint<S, B> {
Route(Route<B>),
}
impl<S, B> Clone for Endpoint<S, B>
where
S: Clone,
{
impl<S, B> Clone for Endpoint<S, B> {
fn clone(&self) -> Self {
match self {
Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()),

View file

@ -56,7 +56,7 @@ impl<T, S, B> FromRequest<S, B> for TypedHeader<T>
where
T: headers::Header,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;

View file

@ -82,7 +82,7 @@ struct PrintRequestBody;
#[async_trait]
impl<S> FromRequest<S, BoxBody> for PrintRequestBody
where
S: Send + Clone,
S: Clone + Send + Sync,
{
type Rejection = Response;

View file

@ -58,7 +58,7 @@ struct Json<T>(T);
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Json<T>
where
S: Send,
S: Send + Sync,
// these trait bounds are copied from `impl FromRequest for axum::Json`
T: DeserializeOwned,
B: axum::body::HttpBody + Send,

View file

@ -57,7 +57,7 @@ where
// these trait bounds are copied from `impl FromRequest for axum::extract::path::Path`
T: DeserializeOwned + Send,
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = (StatusCode, axum::Json<PathError>);

View file

@ -124,7 +124,7 @@ impl AuthBody {
#[async_trait]
impl<S, B> FromRequest<S, B> for Claims
where
S: Send,
S: Send + Sync,
B: Send,
{
type Rejection = AuthError;

View file

@ -63,7 +63,7 @@ pub struct ValidatedForm<T>(pub T);
impl<T, S, B> FromRequest<S, B> for ValidatedForm<T>
where
T: DeserializeOwned + Validate,
S: Send,
S: Send + Sync,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,

View file

@ -51,7 +51,7 @@ enum Version {
impl<S, B> FromRequest<S, B> for Version
where
B: Send,
S: Send,
S: Send + Sync,
{
type Rejection = Response;