mirror of
https://github.com/tokio-rs/axum.git
synced 2025-01-16 22:43:03 +01:00
Add middleware::from_extractor_with_state
(#1396)
Fixes https://github.com/tokio-rs/axum/issues/1373
This commit is contained in:
parent
112f5354ab
commit
611c50ec8b
3 changed files with 92 additions and 41 deletions
|
@ -15,11 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- **fixed:** Support streaming/chunked requests in `ContentLengthLimit` ([#1389])
|
||||
- **fixed:** Used `400 Bad Request` for `FailedToDeserializeQueryString`
|
||||
rejections, instead of `422 Unprocessable Entity` ([#1387])
|
||||
- **added:** Add `middleware::from_extractor_with_state` and
|
||||
`middleware::from_extractor_with_state_arc` ([#1396])
|
||||
- **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397])
|
||||
|
||||
[#1371]: https://github.com/tokio-rs/axum/pull/1371
|
||||
[#1387]: https://github.com/tokio-rs/axum/pull/1387
|
||||
[#1389]: https://github.com/tokio-rs/axum/pull/1389
|
||||
[#1396]: https://github.com/tokio-rs/axum/pull/1396
|
||||
[#1397]: https://github.com/tokio-rs/axum/pull/1397
|
||||
|
||||
# 0.6.0-rc.2 (10. September, 2022)
|
||||
|
|
|
@ -10,6 +10,7 @@ use std::{
|
|||
future::Future,
|
||||
marker::PhantomData,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower_layer::Layer;
|
||||
|
@ -90,8 +91,25 @@ use tower_service::Service;
|
|||
/// ```
|
||||
///
|
||||
/// [`Bytes`]: bytes::Bytes
|
||||
pub fn from_extractor<E>() -> FromExtractorLayer<E> {
|
||||
FromExtractorLayer(PhantomData)
|
||||
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
|
||||
from_extractor_with_state(())
|
||||
}
|
||||
|
||||
/// Create a middleware from an extractor with the given state.
|
||||
///
|
||||
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||
pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
|
||||
from_extractor_with_state_arc(Arc::new(state))
|
||||
}
|
||||
|
||||
/// Create a middleware from an extractor with the given [`Arc`]'ed state.
|
||||
///
|
||||
/// See [`State`](crate::extract::State) for more details about accessing state.
|
||||
pub fn from_extractor_with_state_arc<E, S>(state: Arc<S>) -> FromExtractorLayer<E, S> {
|
||||
FromExtractorLayer {
|
||||
state,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// [`Layer`] that applies [`FromExtractor`] that runs an extractor and
|
||||
|
@ -100,28 +118,39 @@ pub fn from_extractor<E>() -> FromExtractorLayer<E> {
|
|||
/// See [`from_extractor`] for more details.
|
||||
///
|
||||
/// [`Layer`]: tower::Layer
|
||||
pub struct FromExtractorLayer<E>(PhantomData<fn() -> E>);
|
||||
pub struct FromExtractorLayer<E, S> {
|
||||
state: Arc<S>,
|
||||
_marker: PhantomData<fn() -> E>,
|
||||
}
|
||||
|
||||
impl<E> Clone for FromExtractorLayer<E> {
|
||||
impl<E, S> Clone for FromExtractorLayer<E, S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self(PhantomData)
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> fmt::Debug for FromExtractorLayer<E> {
|
||||
impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FromExtractorLayer")
|
||||
.field("state", &self.state)
|
||||
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, S> Layer<S> for FromExtractorLayer<E> {
|
||||
type Service = FromExtractor<S, E>;
|
||||
impl<E, T, S> Layer<T> for FromExtractorLayer<E, S> {
|
||||
type Service = FromExtractor<T, E, S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
fn layer(&self, inner: T) -> Self::Service {
|
||||
FromExtractor {
|
||||
inner,
|
||||
state: Arc::clone(&self.state),
|
||||
_extractor: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -130,52 +159,57 @@ impl<E, S> Layer<S> for FromExtractorLayer<E> {
|
|||
/// Middleware that runs an extractor and discards the value.
|
||||
///
|
||||
/// See [`from_extractor`] for more details.
|
||||
pub struct FromExtractor<S, E> {
|
||||
inner: S,
|
||||
pub struct FromExtractor<T, E, S> {
|
||||
inner: T,
|
||||
state: Arc<S>,
|
||||
_extractor: PhantomData<fn() -> E>,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn traits() {
|
||||
use crate::test_helpers::*;
|
||||
assert_send::<FromExtractor<(), NotSendSync>>();
|
||||
assert_sync::<FromExtractor<(), NotSendSync>>();
|
||||
assert_send::<FromExtractor<(), NotSendSync, ()>>();
|
||||
assert_sync::<FromExtractor<(), NotSendSync, ()>>();
|
||||
}
|
||||
|
||||
impl<S, E> Clone for FromExtractor<S, E>
|
||||
impl<T, E, S> Clone for FromExtractor<T, E, S>
|
||||
where
|
||||
S: Clone,
|
||||
T: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
state: Arc::clone(&self.state),
|
||||
_extractor: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, E> fmt::Debug for FromExtractor<S, E>
|
||||
impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
|
||||
where
|
||||
T: fmt::Debug,
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("FromExtractor")
|
||||
.field("inner", &self.inner)
|
||||
.field("state", &self.state)
|
||||
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, E, B> Service<Request<B>> for FromExtractor<S, E>
|
||||
impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
|
||||
where
|
||||
E: FromRequestParts<()> + 'static,
|
||||
E: FromRequestParts<S> + 'static,
|
||||
B: Default + Send + 'static,
|
||||
S: Service<Request<B>> + Clone,
|
||||
S::Response: IntoResponse,
|
||||
T: Service<Request<B>> + Clone,
|
||||
T::Response: IntoResponse,
|
||||
S: Send + Sync + 'static,
|
||||
{
|
||||
type Response = Response;
|
||||
type Error = S::Error;
|
||||
type Future = ResponseFuture<B, S, E>;
|
||||
type Error = T::Error;
|
||||
type Future = ResponseFuture<B, T, E, S>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -183,9 +217,10 @@ where
|
|||
}
|
||||
|
||||
fn call(&mut self, req: Request<B>) -> Self::Future {
|
||||
let state = Arc::clone(&self.state);
|
||||
let extract_future = Box::pin(async move {
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let extracted = E::from_request_parts(&mut parts, &()).await;
|
||||
let extracted = E::from_request_parts(&mut parts, &state).await;
|
||||
let req = Request::from_parts(parts, body);
|
||||
(req, extracted)
|
||||
});
|
||||
|
@ -202,39 +237,39 @@ where
|
|||
pin_project! {
|
||||
/// Response future for [`FromExtractor`].
|
||||
#[allow(missing_debug_implementations)]
|
||||
pub struct ResponseFuture<B, S, E>
|
||||
pub struct ResponseFuture<B, T, E, S>
|
||||
where
|
||||
E: FromRequestParts<()>,
|
||||
S: Service<Request<B>>,
|
||||
E: FromRequestParts<S>,
|
||||
T: Service<Request<B>>,
|
||||
{
|
||||
#[pin]
|
||||
state: State<B, S, E>,
|
||||
svc: Option<S>,
|
||||
state: State<B, T, E, S>,
|
||||
svc: Option<T>,
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
#[project = StateProj]
|
||||
enum State<B, S, E>
|
||||
enum State<B, T, E, S>
|
||||
where
|
||||
E: FromRequestParts<()>,
|
||||
S: Service<Request<B>>,
|
||||
E: FromRequestParts<S>,
|
||||
T: Service<Request<B>>,
|
||||
{
|
||||
Extracting {
|
||||
future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
|
||||
},
|
||||
Call { #[pin] future: S::Future },
|
||||
Call { #[pin] future: T::Future },
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, S, E> Future for ResponseFuture<B, S, E>
|
||||
impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
|
||||
where
|
||||
E: FromRequestParts<()>,
|
||||
S: Service<Request<B>>,
|
||||
S::Response: IntoResponse,
|
||||
E: FromRequestParts<S>,
|
||||
T: Service<Request<B>>,
|
||||
T::Response: IntoResponse,
|
||||
B: Default,
|
||||
{
|
||||
type Output = Result<Response, S::Error>;
|
||||
type Output = Result<Response, T::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
loop {
|
||||
|
@ -272,29 +307,35 @@ where
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{handler::Handler, routing::get, test_helpers::*, Router};
|
||||
use axum_core::extract::FromRef;
|
||||
use http::{header, request::Parts, StatusCode};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_from_extractor() {
|
||||
#[derive(Clone)]
|
||||
struct Secret(&'static str);
|
||||
|
||||
struct RequireAuth;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> FromRequestParts<S> for RequireAuth
|
||||
where
|
||||
S: Send + Sync,
|
||||
Secret: FromRef<S>,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
_state: &S,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let Secret(secret) = Secret::from_ref(state);
|
||||
if let Some(auth) = parts
|
||||
.headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if auth == "secret" {
|
||||
if auth == secret {
|
||||
return Ok(Self);
|
||||
}
|
||||
}
|
||||
|
@ -305,7 +346,11 @@ mod tests {
|
|||
|
||||
async fn handler() {}
|
||||
|
||||
let app = Router::new().route("/", get(handler.layer(from_extractor::<RequireAuth>())));
|
||||
let state = Secret("secret");
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(handler.layer(from_extractor_with_state::<RequireAuth, _>(state))),
|
||||
);
|
||||
|
||||
let client = TestClient::new(app);
|
||||
|
||||
|
|
|
@ -5,7 +5,10 @@
|
|||
mod from_extractor;
|
||||
mod from_fn;
|
||||
|
||||
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
|
||||
pub use self::from_extractor::{
|
||||
from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor,
|
||||
FromExtractorLayer,
|
||||
};
|
||||
pub use self::from_fn::{
|
||||
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
|
||||
};
|
||||
|
|
Loading…
Reference in a new issue