1
0
Fork 0
mirror of https://github.com/tokio-rs/axum.git synced 2025-04-16 02:30:56 +02:00

Add a separate trait for optional extractors

This commit is contained in:
Jonas Platte 2023-12-30 22:47:46 +01:00
parent d2cea5cdbd
commit 265e65a783
10 changed files with 166 additions and 50 deletions
axum-core/src/extract
axum-extra/src
axum
examples
oauth/src
todos/src

View file

@ -13,11 +13,16 @@ pub mod rejection;
mod default_body_limit;
mod from_ref;
mod option;
mod request_parts;
mod tuple;
pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
pub use self::{
default_body_limit::DefaultBodyLimit,
from_ref::FromRef,
option::{OptionalFromRequest, OptionalFromRequestParts},
};
/// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body
/// type used with axum.
@ -99,35 +104,6 @@ where
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Option<T>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await.ok())
}
}
#[async_trait]
impl<S, T> FromRequest<S> for Option<T>
where
T: FromRequest<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req, state).await.ok())
}
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
where

View file

@ -0,0 +1,60 @@
use async_trait::async_trait;
use http::request::Parts;
use crate::response::IntoResponse;
use super::{private, FromRequest, FromRequestParts, Request};
/// TODO: DOCS
#[async_trait]
pub trait OptionalFromRequestParts<S>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<Self>, Self::Rejection>;
}
/// TODO: DOCS
#[async_trait]
pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection>;
}
#[async_trait]
impl<S, T> FromRequestParts<S> for Option<T>
where
T: OptionalFromRequestParts<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<T>, Self::Rejection> {
T::from_request_parts(parts, state).await
}
}
#[async_trait]
impl<S, T> FromRequest<S> for Option<T>
where
T: OptionalFromRequest<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
T::from_request(req, state).await
}
}

View file

@ -2,7 +2,7 @@
use axum::{
async_trait,
extract::FromRequestParts,
extract::{FromRequestParts, OptionalFromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use headers::{Header, HeaderMapExt};
@ -80,6 +80,31 @@ where
}
}
#[async_trait]
impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
where
T: Header,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
let mut values = parts.headers.get_all(T::name()).iter();
let is_missing = values.size_hint() == (0, Some(0));
match T::decode(&mut values) {
Ok(res) => Ok(Some(Self(res))),
Err(_) if is_missing => Ok(None),
Err(err) => Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Error(err),
}),
}
}
}
axum_core::__impl_deref!(TypedHeader);
impl<T> IntoResponseParts for TypedHeader<T>

View file

@ -114,6 +114,7 @@ rustversion = "1.0.9"
[dev-dependencies]
anyhow = "1.0"
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
axum-macros = { path = "../axum-macros", version = "0.4.0", features = ["__private"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"

View file

@ -219,26 +219,22 @@ and all others implement [`FromRequestParts`].
# Optional extractors
All extractors defined in axum will reject the request if it doesn't match.
If you wish to make an extractor optional you can wrap it in `Option`:
TODO: Docs, more realistic example
```rust,no_run
use axum::{
extract::Json,
routing::post,
Router,
};
use axum::{routing::post, Router};
use axum_extra::{headers::UserAgent, TypedHeader};
use serde_json::Value;
async fn create_user(payload: Option<Json<Value>>) {
if let Some(payload) = payload {
// We got a valid JSON payload
async fn foo(user_agent: Option<TypedHeader<UserAgent>>) {
if let Some(TypedHeader(user_agent)) = user_agent {
// The client sent a user agent
} else {
// Payload wasn't valid JSON
// No user agent header
}
}
let app = Router::new().route("/users", post(create_user));
let app = Router::new().route("/foo", post(foo));
# let _: Router = app;
```

View file

@ -1,6 +1,7 @@
use super::{rejection::*, FromRequestParts};
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use async_trait::async_trait;
use axum_core::extract::OptionalFromRequestParts;
use http::request::Parts;
use std::{collections::HashMap, sync::Arc};
@ -81,6 +82,21 @@ where
}
}
#[async_trait]
impl<S> OptionalFromRequestParts<S> for MatchedPath
where
S: Send + Sync,
{
type Rejection = MatchedPathRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
Ok(parts.extensions.get::<Self>().cloned())
}
}
#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);

View file

@ -18,7 +18,10 @@ mod request_parts;
mod state;
#[doc(inline)]
pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, Request};
pub use axum_core::extract::{
DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, OptionalFromRequest,
OptionalFromRequestParts, Request,
};
#[cfg(feature = "macros")]
pub use axum_macros::{FromRef, FromRequest, FromRequestParts};

View file

@ -1,4 +1,4 @@
use super::{rejection::*, FromRequestParts};
use super::{rejection::*, FromRequestParts, OptionalFromRequestParts};
use async_trait::async_trait;
use http::{request::Parts, Uri};
use serde::de::DeserializeOwned;
@ -59,6 +59,28 @@ where
}
}
#[async_trait]
impl<T, S> OptionalFromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Some(Self(value)))
} else {
Ok(None)
}
}
}
impl<T> Query<T>
where
T: DeserializeOwned,

View file

@ -12,7 +12,7 @@ use anyhow::{Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
async_trait,
extract::{FromRef, FromRequestParts, Query, State},
extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
@ -25,7 +25,7 @@ use oauth2::{
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::env;
use std::{convert::Infallible, env};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
static COOKIE_NAME: &str = "SESSION";
@ -285,6 +285,25 @@ where
}
}
#[async_trait]
impl<S> OptionalFromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<Self>, Self::Rejection> {
match FromRequestParts::from_request_parts(parts, state).await {
Ok(res) => Ok(Some(res)),
Err(AuthRedirect) => Ok(None),
}
}
}
// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]

View file

@ -82,13 +82,11 @@ pub struct Pagination {
}
async fn todos_index(
pagination: Option<Query<Pagination>>,
Query(pagination): Query<Pagination>,
State(db): State<Db>,
) -> impl IntoResponse {
let todos = db.read().unwrap();
let Query(pagination) = pagination.unwrap_or_default();
let todos = todos
.values()
.skip(pagination.offset.unwrap_or(0))