diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index c8e2d219..157bf404 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -13,11 +13,16 @@ pub mod rejection; mod default_body_limit; mod from_ref; +mod option; mod request_parts; mod tuple; pub(crate) use self::default_body_limit::DefaultBodyLimitKind; -pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef}; +pub use self::{ + default_body_limit::DefaultBodyLimit, + from_ref::FromRef, + option::{OptionalFromRequest, OptionalFromRequestParts}, +}; /// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body /// type used with axum. @@ -99,35 +104,6 @@ where } } -#[async_trait] -impl<S, T> FromRequestParts<S> for Option<T> -where - T: FromRequestParts<S>, - S: Send + Sync, -{ - type Rejection = Infallible; - - async fn from_request_parts( - parts: &mut Parts, - state: &S, - ) -> Result<Option<T>, Self::Rejection> { - Ok(T::from_request_parts(parts, state).await.ok()) - } -} - -#[async_trait] -impl<S, T> FromRequest<S> for Option<T> -where - T: FromRequest<S>, - S: Send + Sync, -{ - type Rejection = Infallible; - - async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> { - Ok(T::from_request(req, state).await.ok()) - } -} - #[async_trait] impl<S, T> FromRequestParts<S> for Result<T, T::Rejection> where diff --git a/axum-core/src/extract/option.rs b/axum-core/src/extract/option.rs new file mode 100644 index 00000000..7cc93113 --- /dev/null +++ b/axum-core/src/extract/option.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use http::request::Parts; + +use crate::response::IntoResponse; + +use super::{private, FromRequest, FromRequestParts, Request}; + +/// TODO: DOCS +#[async_trait] +pub trait OptionalFromRequestParts<S>: Sized { + /// If the extractor fails it'll use this "rejection" type. A rejection is + /// a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result<Option<Self>, Self::Rejection>; +} + +/// TODO: DOCS +#[async_trait] +pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized { + /// If the extractor fails it'll use this "rejection" type. A rejection is + /// a kind of error that can be converted into a response. + type Rejection: IntoResponse; + + /// Perform the extraction. + async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection>; +} + +#[async_trait] +impl<S, T> FromRequestParts<S> for Option<T> +where + T: OptionalFromRequestParts<S>, + S: Send + Sync, +{ + type Rejection = T::Rejection; + + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result<Option<T>, Self::Rejection> { + T::from_request_parts(parts, state).await + } +} + +#[async_trait] +impl<S, T> FromRequest<S> for Option<T> +where + T: OptionalFromRequest<S>, + S: Send + Sync, +{ + type Rejection = T::Rejection; + + async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> { + T::from_request(req, state).await + } +} diff --git a/axum-extra/src/typed_header.rs b/axum-extra/src/typed_header.rs index f56f20b1..441fcb1d 100644 --- a/axum-extra/src/typed_header.rs +++ b/axum-extra/src/typed_header.rs @@ -2,7 +2,7 @@ use axum::{ async_trait, - extract::FromRequestParts, + extract::{FromRequestParts, OptionalFromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use headers::{Header, HeaderMapExt}; @@ -80,6 +80,31 @@ where } } +#[async_trait] +impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T> +where + T: Header, + S: Send + Sync, +{ + type Rejection = TypedHeaderRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + let mut values = parts.headers.get_all(T::name()).iter(); + let is_missing = values.size_hint() == (0, Some(0)); + match T::decode(&mut values) { + Ok(res) => Ok(Some(Self(res))), + Err(_) if is_missing => Ok(None), + Err(err) => Err(TypedHeaderRejection { + name: T::name(), + reason: TypedHeaderRejectionReason::Error(err), + }), + } + } +} + axum_core::__impl_deref!(TypedHeader); impl<T> IntoResponseParts for TypedHeader<T> diff --git a/axum/Cargo.toml b/axum/Cargo.toml index da8f5ef2..deac4a51 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -114,6 +114,7 @@ rustversion = "1.0.9" [dev-dependencies] anyhow = "1.0" +axum-extra = { path = "../axum-extra", features = ["typed-header"] } axum-macros = { path = "../axum-macros", version = "0.4.0", features = ["__private"] } quickcheck = "1.0" quickcheck_macros = "1.0" diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 13d27171..f382e677 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -219,26 +219,22 @@ and all others implement [`FromRequestParts`]. # Optional extractors -All extractors defined in axum will reject the request if it doesn't match. -If you wish to make an extractor optional you can wrap it in `Option`: +TODO: Docs, more realistic example ```rust,no_run -use axum::{ - extract::Json, - routing::post, - Router, -}; +use axum::{routing::post, Router}; +use axum_extra::{headers::UserAgent, TypedHeader}; use serde_json::Value; -async fn create_user(payload: Option<Json<Value>>) { - if let Some(payload) = payload { - // We got a valid JSON payload +async fn foo(user_agent: Option<TypedHeader<UserAgent>>) { + if let Some(TypedHeader(user_agent)) = user_agent { + // The client sent a user agent } else { - // Payload wasn't valid JSON + // No user agent header } } -let app = Router::new().route("/users", post(create_user)); +let app = Router::new().route("/foo", post(foo)); # let _: Router = app; ``` diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 6ac0397c..28e14e61 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,6 +1,7 @@ use super::{rejection::*, FromRequestParts}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; use async_trait::async_trait; +use axum_core::extract::OptionalFromRequestParts; use http::request::Parts; use std::{collections::HashMap, sync::Arc}; @@ -81,6 +82,21 @@ where } } +#[async_trait] +impl<S> OptionalFromRequestParts<S> for MatchedPath +where + S: Send + Sync, +{ + type Rejection = MatchedPathRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + Ok(parts.extensions.get::<Self>().cloned()) + } +} + #[derive(Clone, Debug)] struct MatchedNestedPath(Arc<str>); diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index c02bc6f0..29ca9657 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -18,7 +18,10 @@ mod request_parts; mod state; #[doc(inline)] -pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, Request}; +pub use axum_core::extract::{ + DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, OptionalFromRequest, + OptionalFromRequestParts, Request, +}; #[cfg(feature = "macros")] pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index a0c0f77c..619ce163 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,4 +1,4 @@ -use super::{rejection::*, FromRequestParts}; +use super::{rejection::*, FromRequestParts, OptionalFromRequestParts}; use async_trait::async_trait; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; @@ -59,6 +59,28 @@ where } } +#[async_trait] +impl<T, S> OptionalFromRequestParts<S> for Query<T> +where + T: DeserializeOwned, + S: Send + Sync, +{ + type Rejection = QueryRejection; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + if let Some(query) = parts.uri.query() { + let value = serde_urlencoded::from_str(query) + .map_err(FailedToDeserializeQueryString::from_err)?; + Ok(Some(Self(value))) + } else { + Ok(None) + } + } +} + impl<T> Query<T> where T: DeserializeOwned, diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 6ceffe8f..4cfcaeb6 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,7 +12,7 @@ use anyhow::{Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, - extract::{FromRef, FromRequestParts, Query, State}, + extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State}, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, routing::get, @@ -25,7 +25,7 @@ use oauth2::{ ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; -use std::env; +use std::{convert::Infallible, env}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; static COOKIE_NAME: &str = "SESSION"; @@ -285,6 +285,25 @@ where } } +#[async_trait] +impl<S> OptionalFromRequestParts<S> for User +where + MemoryStore: FromRef<S>, + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + match FromRequestParts::from_request_parts(parts, state).await { + Ok(res) => Ok(Some(res)), + Err(AuthRedirect) => Ok(None), + } + } +} + // Use anyhow, define error and enable '?' // For a simplified example of using anyhow in axum check /examples/anyhow-error-response #[derive(Debug)] diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 2fdac41b..17da3670 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -82,13 +82,11 @@ pub struct Pagination { } async fn todos_index( - pagination: Option<Query<Pagination>>, + Query(pagination): Query<Pagination>, State(db): State<Db>, ) -> impl IntoResponse { let todos = db.read().unwrap(); - let Query(pagination) = pagination.unwrap_or_default(); - let todos = todos .values() .skip(pagination.offset.unwrap_or(0))