Add OptionalQuery extractor (#2310)

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
This commit is contained in:
Mikhail Antoshkin 2023-11-18 21:38:30 +09:00 committed by GitHub
parent 6e984b754a
commit 39cc596e45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 205 additions and 2 deletions

View file

@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning].
# Unreleased
- **added:** `OptionalQuery` extractor ([#2310])
- **added:** `TypedHeader` which used to be in `axum` ([#1850])
- **added:** `Clone` implementation for `ErasedJson` ([#2142])
- **breaking:** Update to prost 0.12. Used for the `Protobuf` extractor
@ -14,6 +15,7 @@ and this project adheres to [Semantic Versioning].
[#1850]: https://github.com/tokio-rs/axum/pull/1850
[#2142]: https://github.com/tokio-rs/axum/pull/2142
[#2310]: https://github.com/tokio-rs/axum/pull/2310
# 0.7.4 (18. April, 2023)

View file

@ -31,7 +31,7 @@ pub use self::cookie::SignedCookieJar;
pub use self::form::{Form, FormRejection};
#[cfg(feature = "query")]
pub use self::query::{Query, QueryRejection};
pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection};
#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;

View file

@ -112,6 +112,124 @@ impl std::error::Error for QueryRejection {
}
}
/// Extractor that deserializes query strings into `None` if no query parameters are present.
/// Otherwise behaviour is identical to [`Query`]
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// # Example
///
/// ```rust,no_run
/// use axum::{routing::get, Router};
/// use axum_extra::extract::OptionalQuery;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Pagination {
/// page: usize,
/// per_page: usize,
/// }
///
/// // This will parse query strings like `?page=2&per_page=30` into `Some(Pagination)` and
/// // empty query string into `None`
/// async fn list_things(OptionalQuery(pagination): OptionalQuery<Pagination>) {
/// match pagination {
/// Some(Pagination{ page, per_page }) => { /* return specified page */ },
/// None => { /* return fist page */ }
/// }
/// // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # let _: Router = app;
/// ```
///
/// If the query string cannot be parsed it will reject the request with a `400
/// Bad Request` response.
///
/// For handling values being empty vs missing see the [query-params-with-empty-strings][example]
/// example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);
#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = OptionalQueryRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query).map_err(|err| {
OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err))
})?;
Ok(OptionalQuery(Some(value)))
} else {
Ok(OptionalQuery(None))
}
}
}
impl<T> std::ops::Deref for OptionalQuery<T> {
type Target = Option<T>;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> std::ops::DerefMut for OptionalQuery<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// Rejection used for [`OptionalQuery`].
///
/// Contains one variant for each way the [`OptionalQuery`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
#[cfg(feature = "query")]
pub enum OptionalQueryRejection {
#[allow(missing_docs)]
FailedToDeserializeQueryString(Error),
}
impl IntoResponse for OptionalQueryRejection {
fn into_response(self) -> Response {
match self {
Self::FailedToDeserializeQueryString(inner) => (
StatusCode::BAD_REQUEST,
format!("Failed to deserialize query string: {inner}"),
)
.into_response(),
}
}
}
impl fmt::Display for OptionalQueryRejection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::FailedToDeserializeQueryString(inner) => inner.fmt(f),
}
}
}
impl std::error::Error for OptionalQueryRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::FailedToDeserializeQueryString(inner) => Some(inner),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -121,7 +239,7 @@ mod tests {
use serde::Deserialize;
#[tokio::test]
async fn supports_multiple_values() {
async fn query_supports_multiple_values() {
#[derive(Deserialize)]
struct Data {
#[serde(rename = "value")]
@ -145,4 +263,87 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}
#[tokio::test]
async fn optional_query_supports_multiple_values() {
#[derive(Deserialize)]
struct Data {
#[serde(rename = "value")]
values: Vec<String>,
}
let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
data.map(|Data { values }| values.join(","))
.unwrap_or("None".to_owned())
}),
);
let client = TestClient::new(app);
let res = client
.post("/?value=one&value=two")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body("")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}
#[tokio::test]
async fn optional_query_deserializes_no_parameters_into_none() {
#[derive(Deserialize)]
struct Data {
value: String,
}
let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
match data {
None => "None".into(),
Some(data) => data.value,
}
}),
);
let client = TestClient::new(app);
let res = client.post("/").body("").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "None");
}
#[tokio::test]
async fn optional_query_preserves_parsing_errors() {
#[derive(Deserialize)]
struct Data {
value: String,
}
let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
match data {
None => "None".into(),
Some(data) => data.value,
}
}),
);
let client = TestClient::new(app);
let res = client
.post("/?other=something")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body("")
.send()
.await;
assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
}