1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-05 14:01:05 +02:00

Add Cached extractor ()

* extra: Add `Cached` extractor

`Cached` wraps another extractor and caches its result in request
extensions.

* Use newtype to avoid overriding extensions of the same type

* Rename type param
This commit is contained in:
David Pedersen 2021-11-25 11:14:31 +01:00 committed by GitHub
parent 5a5800c1ae
commit 96b353b556
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 254 additions and 0 deletions
axum-extra
axum/src

View file

@ -15,6 +15,7 @@ erased-json = ["serde", "serde_json"]
[dependencies]
axum = { path = "../axum", version = "0.3" }
http = "0.2"
mime = "0.3"
tower-service = "0.3"

View file

@ -0,0 +1,235 @@
use axum::{
async_trait,
body::{boxed, BoxBody},
extract::{
rejection::{ExtensionRejection, ExtensionsAlreadyExtracted},
Extension, FromRequest, RequestParts,
},
http::Response,
response::IntoResponse,
};
use std::{
fmt,
ops::{Deref, DerefMut},
};
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache purely type based so you can only cache one value of each type. The cache is also
/// local to the current request and not reused across requests.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
/// async_trait,
/// extract::{FromRequest, RequestParts},
/// body::{self, BoxBody},
/// response::IntoResponse,
/// http::{StatusCode, Response},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<B> FromRequest<B> for Session
/// where
/// B: Send,
/// {
/// type Rejection = (StatusCode, String);
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// // load session...
/// # unimplemented!()
/// }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<B> FromRequest<B> for CurrentUser
/// where
/// B: Send,
/// {
/// type Rejection = Response<BoxBody>;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// // loading a `CurrentUser` requires first loading the `Session`
/// //
/// // by using `Cached<Session>` we avoid extracting the session more than
/// // once, in case other extractors for the same request also loads the session
/// let session: Session = Cached::<Session>::from_request(req)
/// .await
/// .map_err(|err| err.into_response().map(body::boxed))?
/// .0;
///
/// // load user from session...
/// # unimplemented!()
/// }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
/// current_user: CurrentUser,
/// // we have to use `Cached<Session>` here otherwise the
/// // cached session would not be used
/// Cached(session): Cached<Session>,
/// ) {
/// // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
#[derive(Clone)]
struct CachedEntry<T>(T);
#[async_trait]
impl<B, T> FromRequest<B> for Cached<T>
where
B: Send,
T: FromRequest<B> + Clone + Send + Sync + 'static,
{
type Rejection = CachedRejection<T::Rejection>;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
match Extension::<CachedEntry<T>>::from_request(req).await {
Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
Err(ExtensionRejection::ExtensionsAlreadyExtracted(err)) => {
Err(CachedRejection::ExtensionsAlreadyExtracted(err))
}
Err(_) => {
let value = T::from_request(req).await.map_err(CachedRejection::Inner)?;
req.extensions_mut()
.ok_or_else(|| {
CachedRejection::ExtensionsAlreadyExtracted(
ExtensionsAlreadyExtracted::default(),
)
})?
.insert(CachedEntry(value.clone()));
Ok(Self(value))
}
}
}
}
impl<T> Deref for Cached<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Cached<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// Rejection used for [`Cached`].
///
/// Contains one variant for each way the [`Cached`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
pub enum CachedRejection<R> {
#[allow(missing_docs)]
ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
#[allow(missing_docs)]
Inner(R),
}
impl<R> IntoResponse for CachedRejection<R>
where
R: IntoResponse,
{
type Body = BoxBody;
type BodyError = <Self::Body as axum::body::HttpBody>::Error;
fn into_response(self) -> Response<Self::Body> {
match self {
Self::ExtensionsAlreadyExtracted(inner) => inner.into_response().map(boxed),
Self::Inner(inner) => inner.into_response().map(boxed),
}
}
}
impl<R> fmt::Display for CachedRejection<R>
where
R: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ExtensionsAlreadyExtracted(inner) => write!(f, "{}", inner),
Self::Inner(inner) => write!(f, "{}", inner),
}
}
}
impl<R> std::error::Error for CachedRejection<R>
where
R: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ExtensionsAlreadyExtracted(inner) => Some(inner),
Self::Inner(inner) => Some(inner),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::Request;
use std::{
convert::Infallible,
sync::atomic::{AtomicU32, Ordering},
time::Instant,
};
#[tokio::test]
async fn works() {
static COUNTER: AtomicU32 = AtomicU32::new(0);
#[derive(Clone, Debug, PartialEq, Eq)]
struct Extractor(Instant);
#[async_trait]
impl<B> FromRequest<B> for Extractor
where
B: Send,
{
type Rejection = Infallible;
async fn from_request(_req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
COUNTER.fetch_add(1, Ordering::SeqCst);
Ok(Self(Instant::now()))
}
}
let mut req = RequestParts::new(Request::new(()));
let first = Cached::<Extractor>::from_request(&mut req).await.unwrap().0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
let second = Cached::<Extractor>::from_request(&mut req).await.unwrap().0;
assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
assert_eq!(first, second);
}
}

View file

@ -0,0 +1,11 @@
//! Additional extractors.
mod cached;
pub use self::cached::Cached;
pub mod rejection {
//! Rejection response types.
pub use super::cached::CachedRejection;
}

View file

@ -43,5 +43,6 @@
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
pub mod extract;
pub mod response;
pub mod routing;

View file

@ -76,6 +76,12 @@ macro_rules! define_rejection {
}
impl std::error::Error for $name {}
impl Default for $name {
fn default() -> Self {
Self
}
}
};
(