Add a separate trait for optional extractors (#2475)

This commit is contained in:
Jonas Platte 2024-12-09 21:54:59 -05:00 committed by GitHub
parent fd11d8efde
commit ec75ee3827
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 306 additions and 84 deletions

View file

@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
# Unreleased
- **breaking:**: `Option<T>` as an extractor now requires `T` to implement the
new trait `OptionalFromRequest` (if used as the last extractor) or
`OptionalFromRequestParts` (other extractors) ([#2475])
[#2475]: https://github.com/tokio-rs/axum/pull/2475
# 0.5.0
## alpha.1

View file

@ -13,11 +13,16 @@ pub mod rejection;
mod default_body_limit;
mod from_ref;
mod option;
mod request_parts;
mod tuple;
pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};
pub use self::{
default_body_limit::DefaultBodyLimit,
from_ref::FromRef,
option::{OptionalFromRequest, OptionalFromRequestParts},
};
/// Type alias for [`http::Request`] whose body type defaults to [`Body`], the most common body
/// type used with axum.
@ -102,33 +107,6 @@ where
}
}
impl<S, T> FromRequestParts<S> for Option<T>
where
T: FromRequestParts<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request_parts(parts, state).await.ok())
}
}
impl<S, T> FromRequest<S> for Option<T>
where
T: FromRequest<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req, state).await.ok())
}
}
impl<S, T> FromRequestParts<S> for Result<T, T::Rejection>
where
T: FromRequestParts<S>,

View file

@ -0,0 +1,63 @@
use std::future::Future;
use http::request::Parts;
use crate::response::IntoResponse;
use super::{private, FromRequest, FromRequestParts, Request};
/// Customize the behavior of `Option<Self>` as a [`FromRequestParts`]
/// extractor.
pub trait OptionalFromRequestParts<S>: Sized {
/// If the extractor fails, it will use this "rejection" type.
///
/// A rejection is a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
}
/// Customize the behavior of `Option<Self>` as a [`FromRequest`] extractor.
pub trait OptionalFromRequest<S, M = private::ViaRequest>: Sized {
/// If the extractor fails, it will use this "rejection" type.
///
/// A rejection is a kind of error that can be converted into a response.
type Rejection: IntoResponse;
/// Perform the extraction.
fn from_request(
req: Request,
state: &S,
) -> impl Future<Output = Result<Option<Self>, Self::Rejection>> + Send;
}
impl<S, T> FromRequestParts<S> for Option<T>
where
T: OptionalFromRequestParts<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;
fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> impl Future<Output = Result<Option<T>, Self::Rejection>> {
T::from_request_parts(parts, state)
}
}
impl<S, T> FromRequest<S> for Option<T>
where
T: OptionalFromRequest<S>,
S: Send + Sync,
{
type Rejection = T::Rejection;
async fn from_request(req: Request, state: &S) -> Result<Option<T>, Self::Rejection> {
T::from_request(req, state).await
}
}

View file

@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- **breaking:** `Option<Query<T>>` no longer swallows all error conditions, instead rejecting the
request in many cases; see its documentation for details ([#2475])
- **changed:** Deprecated `OptionalPath<T>` and `OptionalQuery<T>` ([#2475])
- **fixed:** `Host` extractor includes port number when parsing authority ([#2242])
- **changed:** The `multipart` feature is no longer on by default ([#3058])
- **added:** Add `RouterExt::typed_connect` ([#2961])
@ -16,6 +19,7 @@ and this project adheres to [Semantic Versioning].
- **added:** Add `FileStream` for easy construction of file stream responses ([#3047])
[#2242]: https://github.com/tokio-rs/axum/pull/2242
[#2475]: https://github.com/tokio-rs/axum/pull/2475
[#3058]: https://github.com/tokio-rs/axum/pull/3058
[#2961]: https://github.com/tokio-rs/axum/pull/2961
[#2962]: https://github.com/tokio-rs/axum/pull/2962

View file

@ -24,9 +24,9 @@ pub mod multipart;
#[cfg(feature = "scheme")]
mod scheme;
pub use self::{
cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection,
};
#[allow(deprecated)]
pub use self::optional_path::OptionalPath;
pub use self::{cached::Cached, host::Host, with_rejection::WithRejection};
#[cfg(feature = "cookie")]
pub use self::cookie::CookieJar;
@ -41,7 +41,10 @@ pub use self::cookie::SignedCookieJar;
pub use self::form::{Form, FormRejection};
#[cfg(feature = "query")]
pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection};
#[allow(deprecated)]
pub use self::query::OptionalQuery;
#[cfg(feature = "query")]
pub use self::query::{OptionalQueryRejection, Query, QueryRejection};
#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;

View file

@ -1,5 +1,5 @@
use axum::{
extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path},
extract::{rejection::PathRejection, FromRequestParts, Path},
RequestPartsExt,
};
use serde::de::DeserializeOwned;
@ -31,9 +31,11 @@ use serde::de::DeserializeOwned;
/// .route("/blog/{page}", get(render_blog));
/// # let app: Router = app;
/// ```
#[deprecated = "Use Option<Path<_>> instead"]
#[derive(Debug)]
pub struct OptionalPath<T>(pub Option<T>);
#[allow(deprecated)]
impl<T, S> FromRequestParts<S> for OptionalPath<T>
where
T: DeserializeOwned + Send + 'static,
@ -45,19 +47,15 @@ where
parts: &mut http::request::Parts,
_: &S,
) -> Result<Self, Self::Rejection> {
match parts.extract::<Path<T>>().await {
Ok(Path(params)) => Ok(Self(Some(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(Self(None))
}
Err(e) => Err(e),
}
parts
.extract::<Option<Path<T>>>()
.await
.map(|opt| Self(opt.map(|Path(x)| x)))
}
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use std::num::NonZeroU32;

View file

@ -1,5 +1,5 @@
use axum::{
extract::FromRequestParts,
extract::{FromRequestParts, OptionalFromRequestParts},
response::{IntoResponse, Response},
Error,
};
@ -18,6 +18,19 @@ use std::fmt;
/// with the `multiple` attribute. Those values can be collected into a `Vec` or other sequential
/// container.
///
/// # `Option<Query<T>>` behavior
///
/// If `Query<T>` itself is used as an extractor and there is no query string in
/// the request URL, `T`'s `Deserialize` implementation is called on an empty
/// string instead.
///
/// You can avoid this by using `Option<Query<T>>`, which gives you `None` in
/// the case that there is no query string in the request URL.
///
/// Note that an empty query string is not the same as no query string, that is
/// `https://example.org/` and `https://example.org/?` are not treated the same
/// in this case.
///
/// # Example
///
/// ```rust,no_run
@ -96,6 +109,27 @@ where
}
}
impl<T, S> OptionalFromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query)
.map_err(|err| QueryRejection::FailedToDeserializeQueryString(Error::new(err)))?;
Ok(Some(Self(value)))
} else {
Ok(None)
}
}
}
axum_core::__impl_deref!(Query);
/// Rejection used for [`Query`].
@ -182,9 +216,11 @@ impl std::error::Error for QueryRejection {
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[deprecated = "Use Option<Query<_>> instead"]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);
#[allow(deprecated)]
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
where
T: DeserializeOwned,
@ -204,6 +240,7 @@ where
}
}
#[allow(deprecated)]
impl<T> std::ops::Deref for OptionalQuery<T> {
type Target = Option<T>;
@ -213,6 +250,7 @@ impl<T> std::ops::Deref for OptionalQuery<T> {
}
}
#[allow(deprecated)]
impl<T> std::ops::DerefMut for OptionalQuery<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
@ -260,6 +298,7 @@ impl std::error::Error for OptionalQueryRejection {
}
#[cfg(test)]
#[allow(deprecated)]
mod tests {
use super::*;
use crate::test_helpers::*;

View file

@ -1,7 +1,7 @@
//! Extractor and response for typed headers.
use axum::{
extract::FromRequestParts,
extract::{FromRequestParts, OptionalFromRequestParts},
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use headers::{Header, HeaderMapExt};
@ -78,6 +78,30 @@ where
}
}
impl<T, S> OptionalFromRequestParts<S> for TypedHeader<T>
where
T: Header,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
let mut values = parts.headers.get_all(T::name()).iter();
let is_missing = values.size_hint() == (0, Some(0));
match T::decode(&mut values) {
Ok(res) => Ok(Some(Self(res))),
Err(_) if is_missing => Ok(None),
Err(err) => Err(TypedHeaderRejection {
name: T::name(),
reason: TypedHeaderRejectionReason::Error(err),
}),
}
}
}
axum_core::__impl_deref!(TypedHeader);
impl<T> IntoResponseParts for TypedHeader<T>

View file

@ -8,8 +8,6 @@ struct UsersShow {
id: String,
}
async fn option_handler(_: Option<UsersShow>) {}
async fn result_handler(_: Result<UsersShow, PathRejection>) {}
#[derive(TypedPath, Deserialize)]
@ -20,7 +18,6 @@ async fn result_handler_unit_struct(_: Result<UsersIndex, StatusCode>) {}
fn main() {
_ = axum::Router::<()>::new()
.typed_get(option_handler)
.typed_post(result_handler)
.typed_post(result_handler_unit_struct);
}

View file

@ -20,11 +20,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
This allows middleware to add bodies to requests without needing to manually set `content-length` ([#2897])
- **breaking:** Remove `WebSocket::close`.
Users should explicitly send close messages themselves. ([#2974])
- **breaking:** `Option<Path<T>>` and `Option<Query<T>>` no longer swallow all error conditions,
instead rejecting the request in many cases; see their documentation for details ([#2475])
- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`)
This new variant captures both `key`, `value`, and `message` from named path parameters parse errors,
instead of only deserialization error message in `ErrorKind::Message`. ([#2720])
- **breaking:** Make `serve` generic over the listener and IO types ([#2941])
[#2475]: https://github.com/tokio-rs/axum/pull/2475
[#2897]: https://github.com/tokio-rs/axum/pull/2897
[#2903]: https://github.com/tokio-rs/axum/pull/2903
[#2894]: https://github.com/tokio-rs/axum/pull/2894

View file

@ -116,6 +116,7 @@ features = [
[dev-dependencies]
anyhow = "1.0"
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"

View file

@ -200,33 +200,11 @@ async fn handler(
axum enforces this by requiring the last extractor implements [`FromRequest`]
and all others implement [`FromRequestParts`].
# Optional extractors
# Handling extractor rejections
All extractors defined in axum will reject the request if it doesn't match.
If you wish to make an extractor optional you can wrap it in `Option`:
```rust,no_run
use axum::{
extract::Json,
routing::post,
Router,
};
use serde_json::Value;
async fn create_user(payload: Option<Json<Value>>) {
if let Some(payload) = payload {
// We got a valid JSON payload
} else {
// Payload wasn't valid JSON
}
}
let app = Router::new().route("/users", post(create_user));
# let _: Router = app;
```
Wrapping extractors in `Result` makes them optional and gives you the reason
the extraction failed:
If you want to handle the case of an extractor failing within a specific
handler, you can wrap it in `Result`, with the error being the rejection type
of the extractor:
```rust,no_run
use axum::{
@ -265,10 +243,33 @@ let app = Router::new().route("/users", post(create_user));
# let _: Router = app;
```
Another option is to make use of the optional extractors in [axum-extra] that
either returns `None` if there are no query parameters in the request URI,
or returns `Some(T)` if deserialization was successful.
If the deserialization was not successful, the request is rejected.
# Optional extractors
Some extractors implement [`OptionalFromRequestParts`] in addition to
[`FromRequestParts`], or [`OptionalFromRequest`] in addition to [`FromRequest`].
These extractors can be used inside of `Option`. It depends on the particular
`OptionalFromRequestParts` or `OptionalFromRequest` implementation what this
does: For example for `TypedHeader` from axum-extra, you get `None` if the
header you're trying to extract is not part of the request, but if the header
is present and fails to parse, the request is rejected.
```rust,no_run
use axum::{routing::post, Router};
use axum_extra::{headers::UserAgent, TypedHeader};
use serde_json::Value;
async fn foo(user_agent: Option<TypedHeader<UserAgent>>) {
if let Some(TypedHeader(user_agent)) = user_agent {
// The client sent a user agent
} else {
// No user agent header
}
}
let app = Router::new().route("/foo", post(foo));
# let _: Router = app;
```
# Customizing extractor responses

View file

@ -1,7 +1,8 @@
use super::{rejection::*, FromRequestParts};
use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
use axum_core::extract::OptionalFromRequestParts;
use http::request::Parts;
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, convert::Infallible, sync::Arc};
/// Access the path in the router that matches the request.
///
@ -79,6 +80,20 @@ where
}
}
impl<S> OptionalFromRequestParts<S> for MatchedPath
where
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
Ok(parts.extensions.get::<Self>().cloned())
}
}
#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);

View file

@ -17,7 +17,10 @@ mod request_parts;
mod state;
#[doc(inline)]
pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, Request};
pub use axum_core::extract::{
DefaultBodyLimit, FromRef, FromRequest, FromRequestParts, OptionalFromRequest,
OptionalFromRequestParts, Request,
};
#[cfg(feature = "macros")]
pub use axum_macros::{FromRef, FromRequest, FromRequestParts};

View file

@ -8,7 +8,11 @@ use crate::{
routing::url_params::UrlParams,
util::PercentDecodedStr,
};
use axum_core::response::{IntoResponse, Response};
use axum_core::{
extract::OptionalFromRequestParts,
response::{IntoResponse, Response},
RequestPartsExt as _,
};
use http::{request::Parts, StatusCode};
use serde::de::DeserializeOwned;
use std::{fmt, sync::Arc};
@ -20,6 +24,12 @@ use std::{fmt, sync::Arc};
/// parameters must be valid UTF-8, otherwise `Path` will fail and return a `400
/// Bad Request` response.
///
/// # `Option<Path<T>>` behavior
///
/// You can use `Option<Path<T>>` as an extractor to allow the same handler to
/// be used in a route with parameters that deserialize to `T`, and another
/// route with no parameters at all.
///
/// # Example
///
/// These examples assume the `serde` feature of the [`uuid`] crate is enabled.
@ -176,6 +186,29 @@ where
}
}
impl<T, S> OptionalFromRequestParts<S> for Path<T>
where
T: DeserializeOwned + Send + 'static,
S: Send + Sync,
{
type Rejection = PathRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
match parts.extract::<Self>().await {
Ok(Self(params)) => Ok(Some(Self(params))),
Err(PathRejection::FailedToDeserializePathParams(e))
if matches!(e.kind(), ErrorKind::WrongNumberOfParameters { got: 0, .. }) =>
{
Ok(None)
}
Err(e) => Err(e),
}
}
}
// this wrapper type is used as the deserializer error to hide the `serde::de::Error` impl which
// would otherwise be public if we used `ErrorKind` as the error directly
#[derive(Debug)]

View file

@ -1,4 +1,4 @@
use super::{rejection::*, FromRequestParts};
use super::{rejection::*, FromRequestParts, OptionalFromRequestParts};
use http::{request::Parts, Uri};
use serde::de::DeserializeOwned;
@ -6,7 +6,20 @@ use serde::de::DeserializeOwned;
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// # Example
/// # `Option<Query<T>>` behavior
///
/// If `Query<T>` itself is used as an extractor and there is no query string in
/// the request URL, `T`'s `Deserialize` implementation is called on an empty
/// string instead.
///
/// You can avoid this by using `Option<Query<T>>`, which gives you `None` in
/// the case that there is no query string in the request URL.
///
/// Note that an empty query string is not the same as no query string, that is
/// `https://example.org/` and `https://example.org/?` are not treated the same
/// in this case.
///
/// # Examples
///
/// ```rust,no_run
/// use axum::{
@ -62,6 +75,27 @@ where
}
}
impl<T, S> OptionalFromRequestParts<S> for Query<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = QueryRejection;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_urlencoded::from_str(query)
.map_err(FailedToDeserializeQueryString::from_err)?;
Ok(Some(Self(value)))
} else {
Ok(None)
}
}
}
impl<T> Query<T>
where
T: DeserializeOwned,

View file

@ -11,7 +11,7 @@
use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
extract::{FromRef, FromRequestParts, Query, State},
extract::{FromRef, FromRequestParts, OptionalFromRequestParts, Query, State},
http::{header::SET_COOKIE, HeaderMap},
response::{IntoResponse, Redirect, Response},
routing::get,
@ -24,7 +24,7 @@ use oauth2::{
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::env;
use std::{convert::Infallible, env};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
static COOKIE_NAME: &str = "SESSION";
@ -351,6 +351,24 @@ where
}
}
impl<S> OptionalFromRequestParts<S> for User
where
MemoryStore: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
state: &S,
) -> Result<Option<Self>, Self::Rejection> {
match <User as FromRequestParts<S>>::from_request_parts(parts, state).await {
Ok(res) => Ok(Some(res)),
Err(AuthRedirect) => Ok(None),
}
}
}
// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]