mirror of
https://github.com/tokio-rs/axum.git
synced 2025-04-16 02:30:56 +02:00
Add a separate trait for optional extractors
This commit is contained in:
parent
d2cea5cdbd
commit
265e65a783
10 changed files with 166 additions and 50 deletions
axum-core/src/extract
axum-extra/src
axum
examples
|
@ -13,11 +13,16 @@ pub mod rejection;
|
|||
|
||||
mod default_body_limit;
|
||||
mod from_ref;
|
||||
mod option;
|
||||
mod request_parts;
|
||||
mod tuple;
|
||||
|
||||
pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
|
||||
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
|
||||
pub use self::{
|
||||
default_body_limit::DefaultBodyLimit,
|
||||
from_ref::FromRef,
|
||||
option::{OptionalFromRequest, OptionalFromRequestParts},
|
||||
};
|
||||
|
||||
/// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body
|
||||
/// type used with axum.
|
||||
|
@ -99,35 +104,6 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> FromRequestParts<S> for Option<T>
|
||||
where
|
||||
T: FromRequestParts<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> Result<Option<T>, Self::Rejection> {
|
||||
Ok(T::from_request_parts(parts, state).await.ok())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> FromRequest<S> for Option<T>
|
||||
where
|
||||
T: FromRequest<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
|
||||
Ok(T::from_request(req, state).await.ok())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
|
||||
where
|
||||
|
|
60
axum-core/src/extract/option.rs
Normal file
60
axum-core/src/extract/option.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use async_trait::async_trait;
|
||||
use http::request::Parts;
|
||||
|
||||
use crate::response::IntoResponse;
|
||||
|
||||
use super::{private, FromRequest, FromRequestParts, Request};
|
||||
|
||||
/// TODO: DOCS
|
||||
#[async_trait]
|
||||
pub trait OptionalFromRequestParts<S>: Sized {
|
||||
/// If the extractor fails it'll use this "rejection" type. A rejection is
|
||||
/// a kind of error that can be converted into a response.
|
||||
type Rejection: IntoResponse;
|
||||
|
||||
/// Perform the extraction.
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection>;
|
||||
}
|
||||
|
||||
/// TODO: DOCS
|
||||
#[async_trait]
|
||||
pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized {
|
||||
/// If the extractor fails it'll use this "rejection" type. A rejection is
|
||||
/// a kind of error that can be converted into a response.
|
||||
type Rejection: IntoResponse;
|
||||
|
||||
/// Perform the extraction.
|
||||
async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> FromRequestParts<S> for Option<T>
|
||||
where
|
||||
T: OptionalFromRequestParts<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = T::Rejection;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> Result<Option<T>, Self::Rejection> {
|
||||
T::from_request_parts(parts, state).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S, T> FromRequest<S> for Option<T>
|
||||
where
|
||||
T: OptionalFromRequest<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = T::Rejection;
|
||||
|
||||
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
|
||||
T::from_request(req, state).await
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::FromRequestParts,
|
||||
extract::{FromRequestParts, OptionalFromRequestParts},
|
||||
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
|
||||
};
|
||||
use headers::{Header, HeaderMapExt};
|
||||
|
@ -80,6 +80,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
|
||||
where
|
||||
T: Header,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = TypedHeaderRejection;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
_state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
let mut values = parts.headers.get_all(T::name()).iter();
|
||||
let is_missing = values.size_hint() == (0, Some(0));
|
||||
match T::decode(&mut values) {
|
||||
Ok(res) => Ok(Some(Self(res))),
|
||||
Err(_) if is_missing => Ok(None),
|
||||
Err(err) => Err(TypedHeaderRejection {
|
||||
name: T::name(),
|
||||
reason: TypedHeaderRejectionReason::Error(err),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
axum_core::__impl_deref!(TypedHeader);
|
||||
|
||||
impl<T> IntoResponseParts for TypedHeader<T>
|
||||
|
|
|
@ -114,6 +114,7 @@ rustversion = "1.0.9"
|
|||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0"
|
||||
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
|
||||
axum-macros = { path = "../axum-macros", version = "0.4.0", features = ["__private"] }
|
||||
quickcheck = "1.0"
|
||||
quickcheck_macros = "1.0"
|
||||
|
|
|
@ -219,26 +219,22 @@ and all others implement [`FromRequestParts`].
|
|||
|
||||
# Optional extractors
|
||||
|
||||
All extractors defined in axum will reject the request if it doesn't match.
|
||||
If you wish to make an extractor optional you can wrap it in `Option`:
|
||||
TODO: Docs, more realistic example
|
||||
|
||||
```rust,no_run
|
||||
use axum::{
|
||||
extract::Json,
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use axum::{routing::post, Router};
|
||||
use axum_extra::{headers::UserAgent, TypedHeader};
|
||||
use serde_json::Value;
|
||||
|
||||
async fn create_user(payload: Option<Json<Value>>) {
|
||||
if let Some(payload) = payload {
|
||||
// We got a valid JSON payload
|
||||
async fn foo(user_agent: Option<TypedHeader<UserAgent>>) {
|
||||
if let Some(TypedHeader(user_agent)) = user_agent {
|
||||
// The client sent a user agent
|
||||
} else {
|
||||
// Payload wasn't valid JSON
|
||||
// No user agent header
|
||||
}
|
||||
}
|
||||
|
||||
let app = Router::new().route("/users", post(create_user));
|
||||
let app = Router::new().route("/foo", post(foo));
|
||||
# let _: Router = app;
|
||||
```
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::{rejection::*, FromRequestParts};
|
||||
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
|
||||
use async_trait::async_trait;
|
||||
use axum_core::extract::OptionalFromRequestParts;
|
||||
use http::request::Parts;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
|
@ -81,6 +82,21 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> OptionalFromRequestParts<S> for MatchedPath
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = MatchedPathRejection;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
_state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
Ok(parts.extensions.get::<Self>().cloned())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct MatchedNestedPath(Arc<str>);
|
||||
|
||||
|
|
|
@ -18,7 +18,10 @@ mod request_parts;
|
|||
mod state;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, Request};
|
||||
pub use axum_core::extract::{
|
||||
DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, OptionalFromRequest,
|
||||
OptionalFromRequestParts, Request,
|
||||
};
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use axum_macros::{FromRef, FromRequest, FromRequestParts};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{rejection::*, FromRequestParts};
|
||||
use super::{rejection::*, FromRequestParts, OptionalFromRequestParts};
|
||||
use async_trait::async_trait;
|
||||
use http::{request::Parts, Uri};
|
||||
use serde::de::DeserializeOwned;
|
||||
|
@ -59,6 +59,28 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T, S> OptionalFromRequestParts<S> for Query<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = QueryRejection;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
_state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
if let Some(query) = parts.uri.query() {
|
||||
let value = serde_urlencoded::from_str(query)
|
||||
.map_err(FailedToDeserializeQueryString::from_err)?;
|
||||
Ok(Some(Self(value)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Query<T>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
|
|
|
@ -12,7 +12,7 @@ use anyhow::{Context, Result};
|
|||
use async_session::{MemoryStore, Session, SessionStore};
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::{FromRef, FromRequestParts, Query, State},
|
||||
extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State},
|
||||
http::{header::SET_COOKIE, HeaderMap},
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::get,
|
||||
|
@ -25,7 +25,7 @@ use oauth2::{
|
|||
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::{convert::Infallible, env};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
static COOKIE_NAME: &str = "SESSION";
|
||||
|
@ -285,6 +285,25 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> OptionalFromRequestParts<S> for User
|
||||
where
|
||||
MemoryStore: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &S,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
match FromRequestParts::from_request_parts(parts, state).await {
|
||||
Ok(res) => Ok(Some(res)),
|
||||
Err(AuthRedirect) => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use anyhow, define error and enable '?'
|
||||
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -82,13 +82,11 @@ pub struct Pagination {
|
|||
}
|
||||
|
||||
async fn todos_index(
|
||||
pagination: Option<Query<Pagination>>,
|
||||
Query(pagination): Query<Pagination>,
|
||||
State(db): State<Db>,
|
||||
) -> impl IntoResponse {
|
||||
let todos = db.read().unwrap();
|
||||
|
||||
let Query(pagination) = pagination.unwrap_or_default();
|
||||
|
||||
let todos = todos
|
||||
.values()
|
||||
.skip(pagination.offset.unwrap_or(0))
|
||||
|
|
Loading…
Add table
Reference in a new issue